Commit b8428a7f authored by peng xu

Merge branch 'main' into beam_search

parents e5034150 3f4e71df
@@ -291,7 +291,7 @@ class PretrainedBertModel(MegatronModule):
         pool_mask = (input_ids == self.pad_id).unsqueeze(2)
         # Taking the representation of the [CLS] token of BERT
-        pooled_output = lm_output[:, 0, :]
+        pooled_output = lm_output[0, :, :]
         # Converting to float16 dtype
         pooled_output = pooled_output.to(lm_output.dtype)
...
@@ -15,6 +15,7 @@
 from abc import ABC
 from abc import abstractmethod
+import math
 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -24,18 +25,17 @@ from megatron import mpu
 from .module import MegatronModule

 class MemoryBuffer:

-    def __init__(self, numel, dtype):
+    def __init__(self, numel, numel_padded, dtype):
         self.numel = numel
+        self.numel_padded = numel_padded
         self.dtype = dtype
-        self.data = torch.zeros(self.numel,
+        self.data = torch.zeros(self.numel_padded,
                                 dtype=self.dtype,
                                 device=torch.cuda.current_device(),
                                 requires_grad=False)

     def zero(self):
         """Reset the buffer to zero."""
         self.data.zero_()
@@ -121,8 +121,11 @@ class DistributedDataParallel(DistributedDataParallelBase):
        # the case we use continuous buffers.
        # ===================================
        self._grad_buffers = None
+        self._grad_buffer_param_index_map = None
        if self.use_contiguous_buffers:
            self._grad_buffers = {}
+            self._grad_buffer_param_index_map = {}
+            data_parallel_world_size = mpu.get_data_parallel_world_size()

            # Simple function to define buffer type.
            def _get_buffer_type(param):
@@ -139,7 +142,18 @@ class DistributedDataParallel(DistributedDataParallelBase):
            # Allocate the buffer.
            for dtype, num_elements in type_num_elements.items():
-                self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
+                # If using distributed optimizer, pad memory buffer to be
+                # multiple of data_parallel_world_size. (This padding is done
+                # due to a constraint with the reduce_scatter op, which requires
+                # all tensors have equal size. See: optimizer.py.)
+                num_elements_padded = data_parallel_world_size * \
+                    int(math.ceil(num_elements / data_parallel_world_size))
+                # Allocate grad buffer.
+                self._grad_buffers[dtype] = MemoryBuffer(num_elements,
+                                                         num_elements_padded,
+                                                         dtype)

            # Assume the back prop order is reverse the params order,
            # store the start index for the gradients.
@@ -149,6 +163,12 @@ class DistributedDataParallel(DistributedDataParallelBase):
                    type_num_elements[dtype] -= param.data.nelement()
                    param.main_grad = self._grad_buffers[dtype].get(
                        param.data.shape, type_num_elements[dtype])
+                    if dtype not in self._grad_buffer_param_index_map:
+                        self._grad_buffer_param_index_map[dtype] = {}
+                    self._grad_buffer_param_index_map[dtype][param] = (
+                        type_num_elements[dtype],
+                        type_num_elements[dtype] + param.data.nelement(),
+                    )

        # Backward hook.
        # Accumulation function for the gradients. We need
@@ -170,7 +190,8 @@ class DistributedDataParallel(DistributedDataParallelBase):
        # Hook used for back-prop.
        def param_hook(*unused):
            # Add the gradient to the buffer.
-            if param.grad.data is not None:
+            if param.grad is not None:
+                # The gradient function of linear layers is fused with GEMMs
                param.main_grad.add_(param.grad.data)
                # Now we can deallocate grad memory.
                param.grad = None
...
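As an aside, the padding introduced above can be checked in isolation. The following is a minimal standalone sketch (not Megatron code; the helper name is hypothetical) of why the grad buffer is rounded up to a multiple of the data-parallel world size: reduce_scatter needs equally sized shards, one per rank.

import math

def padded_buffer_size(num_elements, data_parallel_world_size):
    # Round up so the buffer splits into equal shards, one per data-parallel
    # rank, which is what reduce_scatter requires.
    return data_parallel_world_size * int(
        math.ceil(num_elements / data_parallel_world_size))

# Example: 10 gradient elements across 4 data-parallel ranks -> padded to 12,
# i.e. 3 elements per rank; an already-divisible size is left unchanged.
assert padded_buffer_size(10, 4) == 12
assert padded_buffer_size(12, 4) == 12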
@@ -23,6 +23,8 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib
+from megatron.mpu import make_viewless_tensor
 try:
     from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
     HAVE_PERSIST_LAYER_NORM = True
@@ -67,7 +69,9 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
 class MixedFusedLayerNorm(torch.nn.Module):

-    def __init__(self, normalized_shape, eps=1e-5, no_persist_layer_norm=True):
+    def __init__(self, normalized_shape, eps=1e-5,
+                 no_persist_layer_norm=True,
+                 sequence_parallel=False):
         super(MixedFusedLayerNorm, self).__init__()
         global fused_mix_prec_layer_norm_cuda
@@ -92,6 +96,11 @@ class MixedFusedLayerNorm(torch.nn.Module):
         self.bias = Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()
         self.no_persist_layer_norm = no_persist_layer_norm
+        self.sequence_parallel = sequence_parallel
+
+        # set sequence parallelism flag on weight and bias parameters
+        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
+        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

     def reset_parameters(self):
@@ -106,6 +115,15 @@ class MixedFusedLayerNorm(torch.nn.Module):
             return FusedLayerNormAffineFunction.apply(
                 input, self.weight, self.bias, self.normalized_shape, self.eps)
         else:
-            return FastLayerNormFN.apply(
-                input, self.weight, self.bias, self.eps)
+            output = FastLayerNormFN.apply(
+                input, self.weight, self.bias, self.eps)
+            # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
+            # a populated '_base' field). This will result in schedule.py's
+            # deallocate_output_tensor() throwing an error, so a viewless tensor is
+            # created to prevent this.
+            output = make_viewless_tensor(inp = output,
+                                          requires_grad = input.requires_grad,
+                                          keep_graph = True)
+            return output
@@ -32,20 +32,26 @@ def post_language_model_processing(lm_output, labels, logit_weights,
                                    parallel_output,
                                    fp16_lm_cross_entropy):

-    # Output.
+    # Output. Format [s b h]
    output = parallel_lm_logits(
        lm_output,
        logit_weights,
        parallel_output)

    if labels is None:
-        return output
+        # [s b h] => [b s h]
+        return output.transpose(0,1).contiguous()
    else:
+        # [b s] => [s b]
+        labels = labels.transpose(0,1).contiguous()
        if fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = mpu.vocab_parallel_cross_entropy(output, labels)
        else:
            loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
+        # [s b] => [b, s]
+        loss = loss.transpose(0,1).contiguous()
        return loss
...
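For reference, the transposes introduced here just move between the caller-facing [b, s] / [b, s, h] layout and the sequence-first [s, b] / [s, b, h] layout used inside the model; a small illustration in plain PyTorch (dummy shapes, not Megatron code):

import torch

b, s, h = 2, 5, 8
logits = torch.randn(s, b, h)                    # internal layout: [s, b, h]
labels = torch.randint(0, 10, (b, s))            # caller layout:   [b, s]

labels_sb = labels.transpose(0, 1).contiguous()        # [b, s] => [s, b]
per_token_loss = torch.zeros(s, b)                     # loss computed in [s, b]
loss_bs = per_token_loss.transpose(0, 1).contiguous()  # [s, b] => [b, s] for the caller
assert labels_sb.shape == (s, b) and loss_bs.shape == (b, s)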
@@ -26,17 +26,29 @@ from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal, scaled_init_method_normal

 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                        bias=None):
     """LM logits using word embedding weights."""
+    args = get_args()
     # Parallel logits.
-    input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
-    # Matrix multiply.
-    if bias is None:
-        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
+    if args.async_tensor_model_parallel_allreduce or\
+            args.sequence_parallel:
+        input_parallel = input_
+        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
+        async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
+            model_parallel and not args.sequence_parallel
     else:
-        logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
+        input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
+        async_grad_allreduce = False
+
+    # Matrix multiply.
+    logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
+        input_parallel, word_embeddings_weight, bias,
+        args.gradient_accumulation_fusion,
+        async_grad_allreduce, args.sequence_parallel)
     # Gather if needed.
     if parallel_output:
         return logits_parallel
@@ -92,12 +104,23 @@ class Pooler(MegatronModule):
    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
+        args = get_args()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
+        self.sequence_parallel = args.sequence_parallel

    def forward(self, hidden_states, sequence_index=0):
-        # hidden_states: [b, s, h]
+        # hidden_states: [s, b, h]
        # sequence_index: index of the token to pool.
-        pooled = hidden_states[:, sequence_index, :]
+
+        # gather data along sequence dimensions
+        # same pooler is run on all tensor parallel nodes
+        if self.sequence_parallel:
+            hidden_states = mpu.gather_from_sequence_parallel_region(
+                hidden_states,
+                tensor_parallel_output_grad=False)
+
+        pooled = hidden_states[sequence_index, :, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled
@@ -158,6 +181,8 @@ class Embedding(MegatronModule):
        else:
            self.tokentype_embeddings = None

+        self.fp32_residual_connection = args.fp32_residual_connection
+        self.sequence_parallel = args.sequence_parallel
        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
@@ -199,8 +224,20 @@ class Embedding(MegatronModule):
        else:
            assert self.tokentype_embeddings is None

+        # Data format change to avoid explicit transposes : [b s h] --> [s b h].
+        embeddings = embeddings.transpose(0, 1).contiguous()
+
+        # If the input flag for fp32 residual connection is set, convert to float.
+        if self.fp32_residual_connection:
+            embeddings = embeddings.float()
+
        # Dropout.
-        embeddings = self.embedding_dropout(embeddings)
+        if self.sequence_parallel:
+            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
+            with mpu.get_cuda_rng_tracker().fork():
+                embeddings = self.embedding_dropout(embeddings)
+        else:
+            embeddings = self.embedding_dropout(embeddings)

        return embeddings
...
@@ -152,19 +152,24 @@ class T5Model(MegatronModule):
        if self.post_process and self.add_decoder:
            decoder_output, encoder_output = lm_output
-            # Output.
+            # Output. [s, b, h]
            lm_logits = self.lm_head(decoder_output,
                                     self.word_embeddings_weight())

            if lm_labels is None:
-                return lm_logits
+                # [s b h] => [b s h]
+                return lm_logits.transpose(0,1).contiguous()
            else:
+                # [b s] => [s b]
+                lm_labels = lm_labels.transpose(0,1).contiguous()
                if self.fp16_lm_cross_entropy:
                    assert lm_logits.dtype == torch.half
                    lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
                else:
                    lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                               lm_labels)
+                # [s b] => [b s]
+                lm_loss = lm_loss.transpose(0,1).contiguous()
                return lm_loss
        elif self.add_decoder and not self.add_encoder:
            decoder_output, encoder_output = lm_output
...
@@ -15,10 +15,11 @@
 """Transformer."""
 import math
+from contextlib import nullcontext
 import torch
 import torch.nn.functional as F

-from megatron import get_args
+from megatron import get_timers, get_args, get_global_memory_buffer
 from megatron import mpu
 from .module import MegatronModule
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
@@ -27,6 +28,7 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

 """ We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
@@ -42,7 +44,6 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
     hyperparameters: transformer hyperparameters
 """

 class DropPath(MegatronModule):
     """Drop paths (Stochastic Depth) per sample
     (when applied in main path of residual blocks).
@@ -116,11 +117,196 @@ class ParallelMLP(MegatronModule):
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
class SwitchMLP(MegatronModule):
"""
Routes input to one of N MLP "experts"
"""
def __init__(self, init_method, output_layer_init_method):
super(SwitchMLP, self).__init__()
args = get_args()
self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
self.experts = torch.nn.ModuleList()
for i in range(args.num_experts):
self.experts.append(ParallelMLP(init_method, output_layer_init_method))
def forward(self, hidden_states):
# hidden_states: [s, b, h]
s = hidden_states.size(0)
b = hidden_states.size(1)
h = hidden_states.size(2)
route = self.router(hidden_states)
route = torch.nn.functional.softmax(route, dim=2)
max_prob, max_ind = torch.max(route, dim=2)
max_prob = torch.unsqueeze(max_prob, 2) # [s b 1]
# TODO (rprenger) TODO this could be made easier to read
# Converting [s, b, h] to [s*b, h].
# Each vector could be routed differently
hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [s*b h]
max_prob = max_prob.view(-1, max_prob.size(2)) # [s*b 1]
max_ind = max_ind.view(-1) # [s*b]
output_total = torch.empty_like(hidden_states)
output_bias_total = torch.empty_like(hidden_states)
#TODO (rprenger) This does each expert in serial, but it could be parallelized
for expert_num, expert in enumerate(self.experts):
local_indices = (max_ind == expert_num).nonzero()
hidden = hidden_states[local_indices,:]
output, output_bias = expert(hidden)
output_bias = output_bias.expand_as(output)
output_total[local_indices,:] = output
output_bias_total[local_indices,:] = output_bias
output_total = output_total*max_prob
output_bias_total = output_bias_total*max_prob
output_total = output_total.view(s, b, h)
output_bias_total = output_bias_total.view(s, b, h)
return output_total, output_bias_total
class CoreAttention(MegatronModule):
def __init__(self, layer_number,
attn_mask_type=AttnMaskType.padding):
super(CoreAttention, self).__init__()
args = get_args()
self.fp16 = args.fp16
self.bf16 = args.bf16
self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
self.layer_number = max(1, layer_number)
self.attn_mask_type = attn_mask_type
self.sequence_parallel = args.sequence_parallel
projection_size = args.kv_channels * args.num_attention_heads
# Per attention head and per partition values.
world_size = mpu.get_tensor_model_parallel_world_size()
self.hidden_size_per_partition = mpu.divide(projection_size,
world_size)
self.hidden_size_per_attention_head = mpu.divide(
projection_size, args.num_attention_heads)
self.num_attention_heads_per_partition = mpu.divide(
args.num_attention_heads, world_size)
coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
coeff = self.layer_number
self.norm_factor *= coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
self.fp16, self.bf16,
self.attn_mask_type,
args.masked_softmax_fusion,
attention_mask_func,
self.attention_softmax_in_fp32,
coeff)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
def forward(self, query_layer, key_layer,
value_layer, attention_mask):
# ===================================
# Raw attention scores. [b, np, s, s]
# ===================================
# [b, np, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0))
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(output_size[3],
output_size[0] * output_size[1], -1)
# preallocating input tensor: [b * np, sq, sk]
matmul_input_buffer = get_global_memory_buffer().get_tensor(
(output_size[0]*output_size[1], output_size[2], output_size[3]),
query_layer.dtype, "mpu")
# Raw attention scores. [b * np, sq, sk]
matmul_result = torch.baddbmm(
matmul_input_buffer,
query_layer.transpose(0, 1), # [b * np, sq, hn]
key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk]
beta=0.0, alpha=(1.0/self.norm_factor))
# change view to [b, np, sq, sk]
attention_scores = matmul_result.view(*output_size)
# ===========================
# Attention probs and dropout
# ===========================
# attention scores and attention mask [b, np, sq, sk]
attention_probs = self.scale_mask_softmax(attention_scores,
attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
if not self.sequence_parallel:
with mpu.get_cuda_rng_tracker().fork():
attention_probs = self.attention_dropout(attention_probs)
else:
attention_probs = self.attention_dropout(attention_probs)
# =========================
# Context layer. [sq, b, hp]
# =========================
# value_layer -> context layer.
# [sk, b, np, hn] --> [b, np, sq, hn]
# context layer shape: [b, np, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
# change view [sk, b * np, hn]
value_layer = value_layer.view(value_layer.size(0),
output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
attention_probs = attention_probs.view(output_size[0] * output_size[1],
output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# change view [b, np, sq, hn]
context_layer = context_layer.view(*output_size)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer
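As a sanity check on the score computation above: with beta=0.0 the baddbmm call reduces to a batched Q.K^T scaled by 1/norm_factor, so a plain bmm gives the same result. A small self-contained sketch in plain PyTorch (shapes are illustrative only, not Megatron code):

import math
import torch

b_np, sq, sk, hn = 6, 4, 4, 16               # batch*heads, query/key lengths, head dim
q = torch.randn(b_np, sq, hn)
k = torch.randn(b_np, sk, hn)
norm_factor = math.sqrt(hn)

buf = torch.empty(b_np, sq, sk)              # preallocated output buffer (ignored when beta=0)
scores = torch.baddbmm(buf, q, k.transpose(1, 2),
                       beta=0.0, alpha=1.0 / norm_factor)
reference = torch.bmm(q, k.transpose(1, 2)) / norm_factor
assert torch.allclose(scores, reference, atol=1e-5)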
class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

-    Self-attention layer takes input with size [b, s, h]
+    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """
@@ -130,13 +316,6 @@ class ParallelAttention(MegatronModule):
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
-        self.fp16 = args.fp16
-        self.bf16 = args.bf16
-        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
-        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
-        if self.apply_query_key_layer_scaling:
-            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
@@ -146,8 +325,6 @@ class ParallelAttention(MegatronModule):
        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
-        self.hidden_size_per_partition = mpu.divide(projection_size,
-                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
@@ -174,24 +351,9 @@ class ParallelAttention(MegatronModule):
                gather_output=False,
                init_method=init_method)

-        coeff = None
-        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
-        if self.apply_query_key_layer_scaling:
-            coeff = self.layer_number
-            self.norm_factor *= coeff
-        self.scale_mask_softmax = FusedScaleMaskSoftmax(
-            self.fp16, self.bf16,
-            self.attn_mask_type,
-            args.masked_softmax_fusion,
-            attention_mask_func,
-            self.attention_softmax_in_fp32,
-            coeff)
-        # Dropout. Note that for a single iteration, this layer will generate
-        # different outputs on different number of parallel partitions but
-        # on average it should not be partition dependent.
-        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
+        self.core_attention = CoreAttention(self.layer_number,
+                                            self.attn_mask_type)
+        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        # Output.
        self.dense = mpu.RowParallelLinear(
@@ -201,6 +363,23 @@ class ParallelAttention(MegatronModule):
            init_method=output_layer_init_method,
            skip_bias_add=True)
def _checkpointed_attention_forward(self, query_layer, key_layer,
value_layer, attention_mask):
"""Forward method with activation checkpointing."""
def custom_forward(*inputs):
query_layer = inputs[0]
key_layer = inputs[1]
value_layer = inputs[2]
attention_mask = inputs[3]
output_ = self.core_attention(query_layer, key_layer,
value_layer, attention_mask)
return output_
hidden_states = mpu.checkpoint(
custom_forward,
False, query_layer, key_layer, value_layer, attention_mask)
return hidden_states
    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        return torch.empty(
@@ -210,13 +389,11 @@ class ParallelAttention(MegatronModule):
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
@@ -234,7 +411,6 @@ class ParallelAttention(MegatronModule):
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================
@@ -275,7 +451,6 @@ class ParallelAttention(MegatronModule):
             self.hidden_size_per_attention_head)
        query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================
@@ -297,90 +472,16 @@ class ParallelAttention(MegatronModule):
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]
-        # ===================================
-        # Raw attention scores. [b, np, s, s]
-        # ===================================
-        # [b, np, sq, sk]
-        output_size = (query_layer.size(1),
-                       query_layer.size(2),
-                       query_layer.size(0),
-                       key_layer.size(0))
-        # [sq, b, np, hn] -> [sq, b * np, hn]
-        query_layer = query_layer.view(output_size[2],
-                                       output_size[0] * output_size[1], -1)
-        # [sk, b, np, hn] -> [sk, b * np, hn]
-        key_layer = key_layer.view(output_size[3],
-                                   output_size[0] * output_size[1], -1)
-        # preallocting result tensor: [b * np, sq, sk]
-        matmul_result = torch.empty(
-            output_size[0]*output_size[1],
-            output_size[2],
-            output_size[3],
-            dtype=query_layer.dtype,
-            device=torch.cuda.current_device())
-        # Raw attention scores. [b * np, sq, sk]
-        matmul_result = torch.baddbmm(
-            matmul_result,
-            query_layer.transpose(0, 1),   # [b * np, sq, hn]
-            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
-            beta=0.0, alpha=(1.0/self.norm_factor))
-        # change view to [b, np, sq, sk]
-        attention_scores = matmul_result.view(*output_size)
-        # ===========================
-        # Attention probs and dropout
-        # ===========================
-        # attention scores and attention mask [b, np, sq, sk]
-        attention_probs = self.scale_mask_softmax(attention_scores,
-                                                  attention_mask)
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        with mpu.get_cuda_rng_tracker().fork():
-            attention_probs = self.attention_dropout(attention_probs)
-        # =========================
-        # Context layer. [sq, b, hp]
-        # =========================
-        # value_layer -> context layer.
-        # [sk, b, np, hn] --> [b, np, sq, hn]
-        # context layer shape: [b, np, sq, hn]
-        output_size = (value_layer.size(1),
-                       value_layer.size(2),
-                       query_layer.size(0),
-                       value_layer.size(3))
-        # change view [sk, b * np, hn]
-        value_layer = value_layer.view(value_layer.size(0),
-                                       output_size[0] * output_size[1], -1)
-        # change view [b * np, sq, sk]
-        attention_probs = attention_probs.view(output_size[0] * output_size[1],
-                                               output_size[2], -1)
-        # matmul: [b * np, sq, hn]
-        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
-        # change view [b, np, sq, hn]
-        context_layer = context_layer.view(*output_size)
-        # [b, np, sq, hn] --> [sq, b, np, hn]
-        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
-        # [sq, b, np, hn] --> [sq, b, hp]
-        new_context_layer_shape = context_layer.size()[:-2] + \
-            (self.hidden_size_per_partition,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        # ==================================
+        # core attention computation
+        # ==================================
+        if self.checkpoint_core_attention:
+            context_layer = self._checkpointed_attention_forward(
+                query_layer, key_layer, value_layer, attention_mask)
+        else:
+            context_layer = self.core_attention(
+                query_layer, key_layer, value_layer, attention_mask)

        # =================
        # Output. [sq, b, h]
@@ -423,7 +524,7 @@ def bias_dropout_add_fused_inference(x: torch.Tensor,
class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

-    Transformer layer takes input with size [b, s, h] and returns an
+    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """
@@ -447,7 +548,8 @@ class ParallelTransformerLayer(MegatronModule):
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
-            no_persist_layer_norm=args.no_persist_layer_norm)
+            no_persist_layer_norm=args.no_persist_layer_norm,
+            sequence_parallel=args.sequence_parallel)

        # Self attention.
        self.self_attention = ParallelAttention(
@@ -464,7 +566,8 @@ class ParallelTransformerLayer(MegatronModule):
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
-            no_persist_layer_norm=args.no_persist_layer_norm)
+            no_persist_layer_norm=args.no_persist_layer_norm,
+            sequence_parallel=args.sequence_parallel)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
@@ -476,16 +579,26 @@ class ParallelTransformerLayer(MegatronModule):
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
-                no_persist_layer_norm=args.no_persist_layer_norm)
+                no_persist_layer_norm=args.no_persist_layer_norm,
+                sequence_parallel=args.sequence_parallel)

        # MLP
-        self.mlp = ParallelMLP(init_method,
-                               output_layer_init_method)
+        if args.num_experts is not None:
+            self.mlp = SwitchMLP(init_method, output_layer_init_method)
+        else:
+            self.mlp = ParallelMLP(init_method, output_layer_init_method)
+
+        # Set bias+dropout+add fusion grad_enable execution handler.
+        TORCH_MAJOR = int(torch.__version__.split('.')[0])
+        TORCH_MINOR = int(torch.__version__.split('.')[1])
+        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
+        self.bias_dropout_add_exec_handler = \
+            nullcontext if use_nvfuser else torch.enable_grad

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
-        # hidden_states: [b, s, h]
+        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
@@ -515,8 +628,7 @@ class ParallelTransformerLayer(MegatronModule):
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

-        # re-enable torch grad to enable fused optimization.
-        with torch.enable_grad():
+        with self.bias_dropout_add_exec_handler():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
@@ -542,8 +654,7 @@ class ParallelTransformerLayer(MegatronModule):
            else:
                residual = layernorm_input

-            # re-enable torch grad to enable fused optimization.
-            with torch.enable_grad():
+            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
@@ -563,13 +674,23 @@ class ParallelTransformerLayer(MegatronModule):
            residual = layernorm_input

        if self.drop_path is None:
-            # re-enable torch grad to enable fused optimization.
-            with torch.enable_grad():
+            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
+
+            # Jit compiled function creates 'view' tensor. This tensor
+            # potentially gets saved in the MPU checkpoint function context,
+            # which rejects view tensors. While making a viewless tensor here
+            # won't result in memory savings (like the data loader, or
+            # p2p_communication), it serves to document the origin of this
+            # 'view' tensor.
+            output = mpu.make_viewless_tensor(inp = output,
+                                              requires_grad = output.requires_grad,
+                                              keep_graph = True)
        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
@@ -611,22 +732,30 @@ class ParallelTransformer(MegatronModule):
    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
+                 post_layer_norm=True,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        super(ParallelTransformer, self).__init__()
        args = get_args()

+        self.layer_type = layer_type
+        self.model_type = args.model_type
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
+        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Store activation checkpointing flag.
-        self.activations_checkpoint_method = args.activations_checkpoint_method
-        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
-        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
+        self.recompute_granularity = args.recompute_granularity
+        self.recompute_method = args.recompute_method
+        self.recompute_num_layers = args.recompute_num_layers
+        self.distribute_saved_activations = \
+            args.distribute_saved_activations and not args.sequence_parallel
+
+        self.sequence_parallel = args.sequence_parallel

        # Number of layers.
        self.num_layers = mpu.get_num_layers(
@@ -690,12 +819,13 @@ class ParallelTransformer(MegatronModule):
        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

-        if self.post_process:
+        if self.post_process and self.post_layer_norm:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
-                no_persist_layer_norm=args.no_persist_layer_norm)
+                no_persist_layer_norm=args.no_persist_layer_norm,
+                sequence_parallel=args.sequence_parallel)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]
@@ -715,32 +845,33 @@ class ParallelTransformer(MegatronModule):
                return x_
            return custom_forward

-        if self.activations_checkpoint_method == 'uniform':
+        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and checkpoint
            # the input activation of each divided chunk.
            # A method to further reduce memory usage by reducing checkpoints.
            l = 0
            while l < self.num_layers:
                hidden_states = mpu.checkpoint(
-                    custom(l, l + self.activations_checkpoint_num_layers),
-                    self.distribute_checkpointed_activations,
+                    custom(l, l + self.recompute_num_layers),
+                    self.distribute_saved_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
-                l += self.activations_checkpoint_num_layers
+                l += self.recompute_num_layers

-        elif self.activations_checkpoint_method == 'block':
+        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method to fully use the device memory by removing redundant re-computation.
            for l in range(self.num_layers):
-                if l < self.activations_checkpoint_num_layers:
+                if l < self.recompute_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
-                        self.distribute_checkpointed_activations,
+                        self.distribute_saved_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
-            raise ValueError("Invalid activation checkpoint method.")
+            raise ValueError("Invalid activation recompute method.")

        return hidden_states
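The two recompute methods above differ only in which layer chunks get wrapped in mpu.checkpoint; a minimal sketch of the resulting layer grouping (hypothetical helper in plain Python, not part of this commit):

def recompute_plan(num_layers, recompute_num_layers, method):
    # Return the [start, end) layer chunks whose input activations would be
    # checkpointed and recomputed during the backward pass.
    if method == 'uniform':
        # every chunk of `recompute_num_layers` consecutive layers is checkpointed
        return [(l, min(l + recompute_num_layers, num_layers))
                for l in range(0, num_layers, recompute_num_layers)]
    if method == 'block':
        # only the first `recompute_num_layers` layers are checkpointed, one at a time
        return [(l, l + 1) for l in range(min(recompute_num_layers, num_layers))]
    raise ValueError("Invalid activation recompute method.")

# 8 layers, recompute_num_layers=2:
#   uniform -> [(0, 2), (2, 4), (4, 6), (6, 8)]   (all layers recomputed, in chunks)
#   block   -> [(0, 1), (1, 2)]                   (only the first two layers recomputed)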
@@ -757,21 +888,14 @@ class ParallelTransformer(MegatronModule):
    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
+        # hidden_states: [s, b, h]

        # Checks.
        if inference_params:
-            assert self.activations_checkpoint_method is None, \
+            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

-        if self.pre_process:
-            # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
-            # If the input flag for fp32 residual connection is set, convert for float.
-            if self.fp32_residual_connection:
-                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
-            # Otherwise, leave it as is.
-            else:
-                hidden_states = hidden_states.transpose(0, 1).contiguous()
-        else:
+        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor
@@ -792,37 +916,34 @@ class ParallelTransformer(MegatronModule):
        # is called here to be future-proof and corner-case-proof.
        hidden_states = mpu.make_viewless_tensor(
            hidden_states,
-            requires_grad = True,
-            keep_graph = True,
+            requires_grad=True,
+            keep_graph=True,
        )

-        # Transpose encoder output.
-        if encoder_output is not None:
-            encoder_output = encoder_output.transpose(0, 1).contiguous()
-        # Forward pass.
-        if self.activations_checkpoint_method is not None:
-            hidden_states = self._checkpointed_forward(hidden_states,
-                                                       attention_mask,
-                                                       encoder_output,
-                                                       enc_dec_attn_mask)
-        else:
-            for index in range(self.num_layers):
-                layer = self._get_layer(index)
-                hidden_states = layer(
-                    hidden_states,
-                    attention_mask,
-                    encoder_output=encoder_output,
-                    enc_dec_attn_mask=enc_dec_attn_mask,
-                    inference_params=inference_params)
+        if self.sequence_parallel:
+            rng_context = mpu.get_cuda_rng_tracker().fork()
+        else:
+            rng_context = nullcontext()
+
+        with rng_context:
+            # Forward pass.
+            if self.recompute_granularity == 'full':
+                hidden_states = self._checkpointed_forward(hidden_states,
+                                                           attention_mask,
+                                                           encoder_output,
+                                                           enc_dec_attn_mask)
+            else:
+                for index in range(self.num_layers):
+                    layer = self._get_layer(index)
+                    hidden_states = layer(
+                        hidden_states,
+                        attention_mask,
+                        encoder_output=encoder_output,
+                        enc_dec_attn_mask=enc_dec_attn_mask,
+                        inference_params=inference_params)

        # Final layer norm.
-        if self.post_process:
-            # Reverting data format change [s b h] --> [b s h].
-            hidden_states = hidden_states.transpose(0, 1).contiguous()
-            output = self.final_layernorm(hidden_states)
-        else:
-            output = hidden_states
-        return output
+        if self.post_process and self.post_layer_norm:
+            hidden_states = self.final_layernorm(hidden_states)
+
+        return hidden_states
@@ -16,9 +16,11 @@
 """Vision Transformer(VIT) model."""

 import torch
+from torch.nn.init import trunc_normal_
 from megatron import get_args
 from megatron.model.utils import get_linear_layer
 from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
+from megatron.model.vision.mit_backbone import mit_b3_avg
 from megatron.model.module import MegatronModule

 class VitClassificationModel(MegatronModule):
@@ -61,3 +63,35 @@ class VitClassificationModel(MegatronModule):
        hidden_states = self.head(hidden_states)
        return hidden_states
class MitClassificationModel(MegatronModule):
"""Mix vision Transformer Model."""
def __init__(self, num_classes,
pre_process=True, post_process=True):
super(MitClassificationModel, self).__init__()
args = get_args()
self.hidden_size = args.hidden_size
self.num_classes = num_classes
self.backbone = mit_b3_avg()
self.head = torch.nn.Linear(512, num_classes)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, torch.nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, torch.nn.Linear) and m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
pass
def forward(self, input):
hidden_states = self.backbone(input)
hidden_states = self.head(hidden_states)
return hidden_states
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
# copied from https://github.com/facebookresearch/dino/blob/main/main_dino.py
# reworked/refactored some parts to make it run in Megatron.
import math
import apex
import einops
import torch
import numpy as np
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from megatron import get_args, print_rank_0
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone
from megatron.model.module import MegatronModule
from megatron.model.vision.mit_backbone import mit_b5_avg
from megatron.model.vision.esvit_swin_backbone import get_swin
class DINOLoss(torch.nn.Module):
def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
center_momentum=0.9):
super().__init__()
self.student_temp = student_temp
self.center_momentum = center_momentum
self.ncrops = ncrops
self.register_buffer("center", torch.zeros(1, out_dim))
# we apply a warm up for the teacher temperature because
# a too high temperature makes the training unstable at the beginning
self.teacher_temp_schedule = np.concatenate((
np.linspace(warmup_teacher_temp,
teacher_temp, warmup_teacher_temp_epochs),
np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
))
self.teacher_temp = teacher_temp
def forward(self, student_output, teacher_output, iteration):
"""
Cross-entropy between softmax outputs of the teacher
and student network.
"""
args = get_args()
student_out = student_output / self.student_temp
student_out = student_out.chunk(self.ncrops)
epoch = iteration // args.iter_per_epoch
# teacher centering and sharpening
temp = self.teacher_temp_schedule[epoch]
teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
teacher_out = teacher_out.detach().chunk(2)
total_loss = 0
n_loss_terms = 0
for iq, q in enumerate(teacher_out):
for v in range(len(student_out)):
if v == iq:
# we skip cases where student and teacher operate on the same view
continue
loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
total_loss += loss.mean()
n_loss_terms += 1
total_loss /= n_loss_terms
self.update_center(teacher_output)
return total_loss
@torch.no_grad()
def update_center(self, teacher_output):
"""
Update center used for teacher output.
"""
batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
torch.distributed.all_reduce(batch_center)
batch_center = batch_center / (len(teacher_output) * torch.distributed.get_world_size())
self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
class DINOHead(torch.nn.Module):
def __init__(self, in_dim, out_dim, norm_last_layer=True, nlayers=3):
super().__init__()
args = get_args()
hidden_dim = args.dino_head_hidden_size
bottleneck_dim = args.dino_bottleneck_size
nlayers = max(nlayers, 1)
if nlayers == 1:
self.mlp = torch.nn.Linear(in_dim, bottleneck_dim)
else:
layers = [torch.nn.Linear(in_dim, hidden_dim)]
layers.append(torch.nn.GELU())
for _ in range(nlayers - 2):
layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
layers.append(torch.nn.GELU())
layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim))
self.mlp = torch.nn.Sequential(*layers)
self.apply(self._init_weights)
self.last_layer = torch.nn.utils.weight_norm(torch.nn.Linear(bottleneck_dim, out_dim, bias=False))
self.last_layer.weight_g.data.fill_(1)
if norm_last_layer:
self.last_layer.weight_g.requires_grad = False
def _init_weights(self, m):
if isinstance(m, torch.nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, torch.nn.Linear) and m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.mlp(x)
x = torch.nn.functional.normalize(x, dim=-1, p=2)
x = self.last_layer(x)
return x
class MultiCropWrapper(MegatronModule):
"""
Perform forward pass separately on each resolution input.
The inputs corresponding to a single resolution are clubbed and single
forward is run on the same resolution inputs. Hence we do several
forward passes = number of different resolutions used. We then
concatenate all the output features and run the head forward on these
concatenated features.
"""
def __init__(self, backbone, head):
super(MultiCropWrapper, self).__init__()
# disable layers dedicated to ImageNet labels classification
#backbone.fc, backbone.head = torch.nn.Identity(), torch.nn.Identity()
self.backbone = backbone
self.head = head
def forward(self, x):
# convert to list
if not isinstance(x, list):
x = [x]
idx_crops = torch.cumsum(torch.unique_consecutive(
torch.tensor([inp.shape[-1] for inp in x]),
return_counts=True,
)[1], 0)
start_idx = 0
for end_idx in idx_crops:
_out = self.backbone(torch.cat(x[start_idx: end_idx]))
if start_idx == 0:
output = _out
else:
output = torch.cat((output, _out))
start_idx = end_idx
# Run the head forward on the concatenated features.
if self.training:
return self.head(output)
else:
return output
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
warmup_epochs=0, start_warmup_value=0):
warmup_schedule = np.array([])
warmup_iters = warmup_epochs * niter_per_ep
if warmup_epochs > 0:
warmup_schedule = \
np.linspace(start_warmup_value, base_value, warmup_iters)
iters = np.arange(epochs * niter_per_ep - warmup_iters)
schedule = final_value + 0.5 * (base_value - final_value) \
* (1 + np.cos(np.pi * iters / len(iters)))
schedule = np.concatenate((warmup_schedule, schedule))
assert len(schedule) == epochs * niter_per_ep
return schedule
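A brief usage sketch of cosine_scheduler as defined above (illustrative values only): it ramps linearly from start_warmup_value to base_value over the warmup iterations, then follows a cosine from base_value toward final_value.

import numpy as np

# e.g. a teacher-momentum schedule: 10 epochs of 100 iterations, 1 warmup epoch
# ramping 0.0 -> 0.996, then a cosine approach toward 1.0.
schedule = cosine_scheduler(base_value=0.996, final_value=1.0,
                            epochs=10, niter_per_ep=100,
                            warmup_epochs=1, start_warmup_value=0.0)
assert len(schedule) == 10 * 100
assert np.isclose(schedule[0], 0.0)              # start of warmup
assert np.isclose(schedule[99], 0.996)           # end of warmup hits base_value
assert np.isclose(schedule[-1], 1.0, atol=1e-4)  # final value approached at the end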
def get_student_backbone_and_num_features(pre_process=True, post_process=True):
args = get_args()
if args.vision_backbone_type == 'vit':
student = VitBackbone(pre_process=pre_process,
post_process=post_process,
drop_path_rate=0.1,
single_token_output=True)
num_features = args.hidden_size
elif args.vision_backbone_type == 'mit':
student = mit_b5_avg(drop_path_rate=0.1)
num_features = 512
elif args.vision_backbone_type == 'swin':
student = get_swin()
num_features = student.num_features
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return student, num_features
def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
args = get_args()
if args.vision_backbone_type == 'vit':
teacher = VitBackbone(pre_process=pre_process,
post_process=post_process,
single_token_output=True)
num_features = args.hidden_size
elif args.vision_backbone_type == 'mit':
teacher = mit_b5_avg(drop_path_rate=0.0)
num_features = 512
elif args.vision_backbone_type == 'swin':
teacher = get_swin(is_teacher=True)
num_features = teacher.num_features
else:
raise Exception('{} vision backbone is not supported.'.format(
args.vision_backbone_type))
return teacher, num_features
class DINOPretrainModel(MegatronModule):
def __init__(self, pre_process=True, post_process=True):
super(DINOPretrainModel, self).__init__()
args = get_args()
self.out_dim = 65536
self.dino_loss = DINOLoss(
self.out_dim,
args.dino_local_crops_number + 2,
args.dino_warmup_teacher_temp,
args.dino_teacher_temp,
args.dino_warmup_teacher_temp_epochs,
300,
)
self.pre_process = pre_process
self.post_process = post_process
self.momentum_teacher = 0.996
student_backbone, num_features = \
get_student_backbone_and_num_features(pre_process, post_process)
self.student = MultiCropWrapper(
student_backbone,
DINOHead(num_features, self.out_dim,
norm_last_layer=args.dino_norm_last_layer)
)
self.momentum_schedule = cosine_scheduler(
self.momentum_teacher, 1,
args.train_iters // args.iter_per_epoch,
args.iter_per_epoch
)
teacher_backbone, num_features = \
get_teacher_backbone_and_num_features(pre_process, post_process)
self.teacher = MultiCropWrapper(
teacher_backbone,
DINOHead(num_features, self.out_dim)
)
self.teacher.load_state_dict(self.student.state_dict())
for p in self.teacher.parameters():
if hasattr(p, "requires_grad") and p.requires_grad is not None:
p.requires_grad = False
def set_input_tensor(self, tensor):
pass
def forward(self, input):
student_output = None
if self.training:
student_output = self.student(input)
teacher_output = self.teacher(input[:2])
else:
teacher_output = self.teacher(input)
return student_output, teacher_output
def cancel_gradients_last_layer(self, iteration):
args = get_args()
epoch = iteration // args.iter_per_epoch
if epoch < args.dino_freeze_last_layer:
for n, p in self.student.named_parameters():
if "last_layer" in n:
p.grad = None
def update_momentum(self, iteration):
with torch.no_grad():
m = self.momentum_schedule[iteration]
for param_q, param_k in zip(self.student.parameters(), self.teacher.parameters()):
param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
# Copyright (c) 2021 Microsoft
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Modified by Chunyuan Li (chunyl@microsoft.com)
# Swin Transformer
# --------------------------------------------------------
import os
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import torch.distributed as dist
from torch.nn.init import trunc_normal_
from megatron.model.transformer import DropPath
from megatron import get_args
from megatron.model import LayerNorm
import numpy as np
from math import sqrt
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None,
out_features=None, act_layer=nn.GELU, drop=0.):
super(Mlp, self).__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
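# Example sketch (for illustration only; not used elsewhere in this file): when H
# and W are multiples of window_size, window_partition and window_reverse are exact
# inverses of each other. The shapes below are arbitrary example values.
def _example_window_roundtrip():
    x = torch.randn(2, 8, 8, 16)          # B=2, H=W=8, C=16
    windows = window_partition(x, 4)      # (2 * (8 // 4) ** 2, 4, 4, 16) = (8, 4, 4, 16)
    y = window_reverse(windows, 4, 8, 8)  # back to (2, 8, 8, 16)
    assert torch.equal(x, y)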
class WindowAttention(nn.Module):
r"""Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super(WindowAttention, self).__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2 Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0).type(attn.type())
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn_out = attn
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x, attn_out
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
@staticmethod
def compute_macs(module, input, output):
B, N, C = input[0].shape
module.__flops__ += module.flops(N) * B
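# Example sketch (for illustration only; not used elsewhere in this file): the
# relative_position_index buffer maps every (query, key) pair inside a window to a
# row of relative_position_bias_table. The construction below repeats the steps from
# WindowAttention.__init__ for a small window; Wh = Ww = 2 is an arbitrary choice.
def _example_relative_position_index(Wh=2, Ww=2):
    coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)]))  # 2, Wh, Ww
    coords = torch.flatten(coords, 1)                                           # 2, Wh*Ww
    rel = coords[:, :, None] - coords[:, None, :]                               # 2, Wh*Ww, Wh*Ww
    rel = rel.permute(1, 2, 0).contiguous()
    rel[:, :, 0] += Wh - 1
    rel[:, :, 1] += Ww - 1
    rel[:, :, 0] *= 2 * Ww - 1
    # (Wh*Ww, Wh*Ww) index tensor with values in [0, (2*Wh-1)*(2*Ww-1) - 1]
    return rel.sum(-1)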
class SwinTransformerBlock(nn.Module):
r"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.H = input_resolution[0]
self.W = input_resolution[1]
self.attn_mask_dict = {}
def create_attn_mask(self, H, W):
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1)) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x):
B, L, C = x.shape
H = int(sqrt(L))
W = H
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# pad feature maps to multiples of window size
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
if H in self.attn_mask_dict.keys():
attn_mask = self.attn_mask_dict[H]
else:
self.attn_mask_dict[H] = self.create_attn_mask(self.H, self.W).to(x.device)
attn_mask = self.attn_mask_dict[H]
else:
shifted_x = x
attn_mask = None
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows, attn = self.attn(x_windows, attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x, attn
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size} mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
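# Example sketch (for illustration only; not used elsewhere in this file): for a
# shifted block, create_attn_mask labels image regions, partitions the labels into
# windows, and marks cross-region token pairs with -100 so softmax suppresses them.
# The dimensions below are arbitrary example values.
def _example_shifted_window_mask():
    blk = SwinTransformerBlock(dim=16, input_resolution=(8, 8), num_heads=2,
                               window_size=4, shift_size=2)
    mask = blk.create_attn_mask(8, 8)
    # (num_windows, window_size**2, window_size**2) = (4, 16, 16);
    # entries are 0.0 for same-region pairs and -100.0 otherwise.
    return mask.shape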
class PatchMerging(nn.Module):
r"""Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B, L, C = x.shape
H = int(sqrt(L))
W = H
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
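# Example sketch (for illustration only; not used elsewhere in this file):
# PatchMerging halves the spatial resolution and doubles the channel width, taking
# (B, H*W, C) to (B, H/2 * W/2, 2*C). Sizes below are arbitrary example values.
def _example_patch_merging():
    merge = PatchMerging(input_resolution=(8, 8), dim=16)
    x = torch.randn(2, 8 * 8, 16)
    y = merge(x)
    assert y.shape == (2, 16, 32)   # (B, (8 // 2) * (8 // 2), 2 * 16)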
class BasicLayer(nn.Module):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
x, _ = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
def forward_with_features(self, x):
fea = []
for blk in self.blocks:
x, _ = blk(x)
fea.append(x)
if self.downsample is not None:
x = self.downsample(x)
return x, fea
def forward_with_attention(self, x):
attns = []
for blk in self.blocks:
x, attn = blk(x)
attns.append(attn)
if self.downsample is not None:
x = self.downsample(x)
return x, attns
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
super().__init__()
img_size = (img_size, img_size)
patch_size = (patch_size, patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
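# Example sketch (for illustration only; not used elsewhere in this file):
# PatchEmbed is a strided convolution, so a 224x224 image with patch_size=4 becomes
# (224 / 4) ** 2 = 3136 tokens of width embed_dim. Sizes below are arbitrary.
def _example_patch_embed():
    embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96)
    tokens = embed(torch.randn(1, 3, 224, 224))
    assert tokens.shape == (1, 3136, 96)   # B, Ph*Pw, embed_dim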
class SwinTransformer(nn.Module):
r""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size.
patch_size (int | tuple(int)): Patch size.
in_chans (int): Number of input channels.
num_classes (int): Number of classes for classification head.
embed_dim (int): Embedding dimension.
depths (tuple(int)): Depth of Swin Transformer layers.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate.
drop_path_rate (float): Stochastic depth rate.
norm_layer (nn.Module): normalization layer.
ape (bool): If True, add absolute position embedding to the patch embedding.
patch_norm (bool): If True, add normalization after patch embedding.
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, **kwargs):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
# todo: to be implemented
return {'relative_position_bias_table'}
def forward(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x_region = self.norm(x) # B L C
x = self.avgpool(x_region.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
return x
def forward_feature_maps(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x_grid = self.norm(x) # B L C
x = self.avgpool(x_grid.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
return x, x_grid
def forward_selfattention(self, x, n=1):
# n=1 return the last layer attn map; otherwise return attn maps in all layers
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
if n==1:
return self.forward_last_selfattention(x)
else:
return self.forward_all_selfattention(x)
def forward_last_selfattention(self, x):
for i, layer in enumerate(self.layers):
if i < len(self.layers) - 1:
x = layer(x)
else:
x, attns = layer.forward_with_attention(x)
return attns[-1]
def forward_all_selfattention(self, x):
attn_out = []
for layer in self.layers:
x, attns = layer.forward_with_attention(x)
attn_out += attns
return attn_out
def forward_return_n_last_blocks(self, x, n=1, return_patch_avgpool=False, depth=[]):
num_blks = sum(depth)
start_idx = num_blks - n
sum_cur = 0
for i, d in enumerate(depth):
sum_cur_new = sum_cur + d
if start_idx >= sum_cur and start_idx < sum_cur_new:
start_stage = i
start_blk = start_idx - sum_cur
sum_cur = sum_cur_new
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
# we will return the averaged token features from the `n` last blocks
# note: there is no [CLS] token in Swin Transformer
output = []
s = 0
for i, layer in enumerate(self.layers):
x, fea = layer.forward_with_features(x)
if i >= start_stage:
for x_ in fea[start_blk:]:
if i == len(self.layers)-1: # use the norm in the last stage
x_ = self.norm(x_)
x_avg = torch.flatten(self.avgpool(x_.transpose(1, 2)), 1) # B C
# print(f'Stage {i}, x_avg {x_avg.shape}')
output.append(x_avg)
start_blk = 0
return torch.cat(output, dim=-1)
def flops(self):
flops = 0
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
if dist.get_rank() == 0:
print(f"GFLOPs layer_{i}: {layer.flops() / 1e9}")
flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
flops += self.num_features * self.num_classes
return flops
def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):
if os.path.isfile(pretrained):
pretrained_dict = torch.load(pretrained, map_location='cpu')
logging.info(f'=> loading pretrained model {pretrained}')
model_dict = self.state_dict()
pretrained_dict = {
k: v for k, v in pretrained_dict.items()
if k in model_dict.keys()
}
need_init_state_dict = {}
for k, v in pretrained_dict.items():
need_init = (
k.split('.')[0] in pretrained_layers
or pretrained_layers[0] == '*'
or 'relative_position_index' not in k
or 'attn_mask' not in k
)
if need_init:
if verbose:
logging.info(f'=> init {k} from {pretrained}')
if 'relative_position_bias_table' in k and v.size() != model_dict[k].size():
relative_position_bias_table_pretrained = v
relative_position_bias_table_current = model_dict[k]
L1, nH1 = relative_position_bias_table_pretrained.size()
L2, nH2 = relative_position_bias_table_current.size()
if nH1 != nH2:
logging.info(f"Error in loading {k}, passing")
else:
if L1 != L2:
logging.info(
'=> load_pretrained: resized variant: {} to {}'
.format((L1, nH1), (L2, nH2))
)
S1 = int(L1 ** 0.5)
S2 = int(L2 ** 0.5)
relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
size=(S2, S2),
mode='bicubic')
v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
if 'absolute_pos_embed' in k and v.size() != model_dict[k].size():
absolute_pos_embed_pretrained = v
absolute_pos_embed_current = model_dict[k]
_, L1, C1 = absolute_pos_embed_pretrained.size()
_, L2, C2 = absolute_pos_embed_current.size()
if C1 != C2:
logging.info(f"Error in loading {k}, passing")
else:
if L1 != L2:
logging.info(
'=> load_pretrained: resized variant: {} to {}'
.format((1, L1, C1), (1, L2, C2))
)
S1 = int(L1 ** 0.5)
S2 = int(L2 ** 0.5)
absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2)
need_init_state_dict[k] = v
self.load_state_dict(need_init_state_dict, strict=False)
def freeze_pretrained_layers(self, frozen_layers=[]):
for name, module in self.named_modules():
if (
name.split('.')[0] in frozen_layers
or '.'.join(name.split('.')[0:2]) in frozen_layers
or (len(frozen_layers) > 0 and frozen_layers[0] == '*')
):
for _name, param in module.named_parameters():
param.requires_grad = False
logging.info(
'=> set param {} requires grad to False'
.format(name)
)
for name, param in self.named_parameters():
if (
name.split('.')[0] in frozen_layers
or (len(frozen_layers) > 0 and frozen_layers[0] == '*')
and param.requires_grad is True
):
param.requires_grad = False
logging.info(
'=> set param {} requires grad to False'
.format(name)
)
return self
def get_swin(is_teacher=False):
args = get_args()
if args.swin_backbone_type == "tiny":
embed_dim = 96
depths = [2, 2, 6, 2]
num_heads = [3, 6, 12, 24]
drop_path_rate = 0.1
elif args.swin_backbone_type == 'h3':
embed_dim = 384
depths = [2, 2, 18, 2]
num_heads = [6, 12, 24, 48]
drop_path_rate = 0.2
else:
embed_dim = 128
depths = [2, 2, 18, 2]
num_heads = [4, 8, 16, 32]
drop_path_rate = 0.2
swin = SwinTransformer(
img_size=224,
in_chans=3,
num_classes=1000,
patch_size=4,
embed_dim=embed_dim,
depths=depths,
num_heads=num_heads,
window_size=7,
mlp_ratio=4,
qkv_bias=True,
drop_rate=0,
attn_drop_rate=0,
drop_path_rate=(0.0 if is_teacher else drop_path_rate),
norm_layer=partial(LayerNorm, eps=1e-6),
ape=False,
patch_norm=True,
)
return swin
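# Example sketch (for illustration only; not used elsewhere in this file): building
# the 'tiny' variant directly, bypassing get_args(). The hyper-parameters mirror the
# 'tiny' branch of get_swin above; drop_path_rate is set to 0.0 here purely to keep
# the sketch minimal.
def _example_tiny_swin():
    model = SwinTransformer(
        img_size=224, patch_size=4, in_chans=3, num_classes=1000,
        embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
        window_size=7, drop_path_rate=0.0)
    feats = model(torch.randn(2, 3, 224, 224))
    return feats.shape   # (2, model.num_features) = (2, 96 * 2 ** 3) = (2, 768)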
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import math
import apex
import einops
import torch
import torch.nn.functional as F
from megatron import get_args, print_rank_0
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone
from megatron.model.module import MegatronModule
from megatron.model.vision.mit_backbone import mit_b3
from megatron.model.vision.utils import resize
class VitInpaintingModel(MegatronModule):
def __init__(self, pre_process=True, post_process=True):
super(VitInpaintingModel, self).__init__()
args = get_args()
self.pre_process = pre_process
self.post_process = post_process
self.hidden_size = args.hidden_size
self.backbone = VitBackbone(
pre_process=self.pre_process,
post_process=self.post_process,
class_token=False,
)
self.patch_dim = args.patch_dim
self.img_h = args.img_h
self.img_w = args.img_w
self.seq_length = args.seq_length
# full mask
if self.post_process:
self.linear_decoder = get_linear_layer(
self.hidden_size,
self.backbone.flatten_dim,
torch.nn.init.zeros_
)
def set_input_tensor(self, input_tensor):
self.backbone.set_input_tensor(input_tensor)
def forward(self, input):
hidden_states = self.backbone(input)
if not self.post_process:
return hidden_states
decoded_output = self.linear_decoder(hidden_states)
output = einops.rearrange(
decoded_output,
"b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
p1=self.patch_dim,
p2=self.patch_dim,
h=self.img_h//self.patch_dim,
w=self.img_w//self.patch_dim,
)
return output
class MLP(torch.nn.Module):
"""
Linear Embedding
"""
def __init__(self, input_dim=2048, embed_dim=768):
super().__init__()
self.proj = torch.nn.Linear(input_dim, embed_dim)
def forward(self, x):
x = x.flatten(2).transpose(1, 2)
x = self.proj(x)
return x
class MitInpaintingModel(MegatronModule):
"""Mix vision Transformer Model."""
def __init__(self, pre_process=True, post_process=True):
super(MitInpaintingModel, self).__init__()
self.pre_process = pre_process
self.post_process = post_process
args = get_args()
self.patch_dim = args.patch_dim
self.img_h = args.img_h
self.img_w = args.img_w
self.flatten_dim = self.patch_dim * self.patch_dim * 3
self.backbone = mit_b3()
self.in_channels = [64, 128, 320, 512]
self.embedding_dim = 768
c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)
self.conv_fuse = torch.nn.Conv2d(self.embedding_dim*4, self.embedding_dim, 1, 1, bias=False)
self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
self.dropout = torch.nn.Dropout2d(0.1)
self.linear_pred = torch.nn.Conv2d(self.embedding_dim, self.flatten_dim, kernel_size=1)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
pass
def forward(self, input):
c1, c2, c3, c4 = self.backbone(input)
n, _, h, w = c4.shape
_c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
_c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
_c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
_c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)
_c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])
_c = torch.cat([_c4, _c3, _c2, _c1], dim=1)
_c = self.conv_fuse(_c)
x = self.norm(_c)
x = F.relu(x, inplace=True)
x = self.dropout(x)
x = self.linear_pred(x)
output = einops.rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.patch_dim,
p2=self.patch_dim,
h=self.img_h//self.patch_dim,
w=self.img_w//self.patch_dim,
)
return output
import torch.nn.functional as F
import torch
from megatron import print_rank_0, get_args, mpu
from megatron.data.vit_dataset import ClassificationTransform
from megatron.data.image_folder import ImageFolder
_FEATURE_BANK = None
def build_data_loader(dataset, drop_last=True, shuffle=False):
"""Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler.
args = get_args()
micro_batch_size = 16
num_workers = args.num_workers
world_size = mpu.get_data_parallel_world_size()
rank = mpu.get_data_parallel_rank()
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=world_size, rank=rank,
drop_last=drop_last, shuffle=shuffle
)
# Data loader. Note that batch size is the per GPU batch size.
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=micro_batch_size,
sampler=sampler,
shuffle=False,
num_workers=num_workers,
drop_last=not drop_last,
pin_memory=True,
)
return data_loader
def compute_feature_bank(model):
args = get_args()
global _FEATURE_BANK
feature_bank = []
feature_label = []
train_ds = ImageFolder(
root=args.data_path[0],
transform=ClassificationTransform((args.img_h, args.img_w), train=False),
data_per_class_fraction=1.0
)
classes = len(train_ds.classes)
dataloader = build_data_loader(train_ds)
for m in model:
m.eval()
with torch.no_grad():
for i, batch in enumerate(dataloader):
images = batch[0].cuda().contiguous()
labels = batch[1].cuda().contiguous()
student_feature, teacher_feature = model[0](images)
feature = F.normalize(teacher_feature.float(), dim=1)
feature_bank.append(feature)
feature_label.append(labels)
for m in model:
m.train()
# [N', D]
feature_bank = torch.cat(feature_bank, dim=0).contiguous()
feature_label = torch.cat(feature_label, dim=0).contiguous()
feature_banks = [torch.zeros_like(feature_bank)
for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather(feature_banks,
feature_bank,
group=mpu.get_data_parallel_group())
assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()],
feature_bank))
feature_labels = [torch.zeros_like(feature_label)
for i in range(mpu.get_data_parallel_world_size())]
torch.distributed.all_gather(feature_labels,
feature_label,
group=mpu.get_data_parallel_group())
# [D, N]
feature_banks = torch.cat(feature_banks, dim=0).t().contiguous()
# [N]
feature_labels = torch.cat(feature_labels, dim=0).contiguous()
print_rank_0("feature_banks size is {}".format(feature_banks.size()))
print_rank_0("feature labels size is {}".format(feature_labels.size()))
_FEATURE_BANK = (feature_banks, feature_labels, classes)
def get_feature_bank():
global _FEATURE_BANK
assert _FEATURE_BANK is not None
return _FEATURE_BANK
# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and
# https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
# compute cos similarity between each feature vector and feature bank ---> [B, N]
sim_matrix = torch.mm(feature, feature_bank)
# [B, K]
sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
# [B, K]
sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1),
dim=-1,
index=sim_indices)
sim_weight = (sim_weight / knn_t).exp()
# counts for each class
one_hot_label = torch.zeros(feature.size(0) * knn_k,
classes,
device=sim_labels.device)
# [B*K, C]
one_hot_label = one_hot_label.scatter(dim=-1,
index=sim_labels.view(-1, 1),
value=1.0)
# weighted score ---> [B, C]
pred_scores = torch.sum(
one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1),
dim=1)
pred_labels = pred_scores.argsort(dim=-1, descending=True)
return pred_labels
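# Example sketch (for illustration only; not used elsewhere in this file):
# knn_predict expects L2-normalized query features of shape [B, D], a bank stored as
# [D, N] (note the transpose done in compute_feature_bank), bank labels of shape [N],
# and returns the classes ranked per query. All sizes below are arbitrary.
def _example_knn_predict():
    B, D, N, C = 4, 8, 32, 5
    feature = F.normalize(torch.randn(B, D), dim=1)
    feature_bank = F.normalize(torch.randn(D, N), dim=0)
    feature_labels = torch.randint(0, C, (N,))
    pred = knn_predict(feature, feature_bank, feature_labels,
                       classes=C, knn_k=10, knn_t=0.07)
    assert pred.shape == (B, C)   # column 0 holds the top-1 class per query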
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# found in the LICENSE file in the root directory of this
# source tree.
# ---------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from torch.nn.init import trunc_normal_
from megatron.model.transformer import DropPath
from megatron.model import LayerNorm
class Mlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = self.fc1(x)
x = self.dwconv(x, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
sr_ratio=1):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = LayerNorm(dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
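# Example sketch (for illustration only; not used elsewhere in this file): with
# sr_ratio > 1 the keys/values are computed from a spatially reduced map of
# (H / sr) * (W / sr) tokens, while the query (and output) length stays H*W.
# Sizes below are arbitrary example values.
def _example_spatial_reduction_attention():
    H = W = 16
    attn = Attention(dim=64, num_heads=2, sr_ratio=4)
    x = torch.randn(2, H * W, 64)
    y = attn(x, H, W)              # K/V come from a 4x4 map (16 tokens)
    assert y.shape == x.shape      # (2, 256, 64)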
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm, sr_ratio=1):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
return x
class OverlapPatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
img_size = (img_size, img_size)
patch_size = (patch_size, patch_size)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
class MixVisionTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNorm,
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], output_avg=False):
super().__init__()
self.num_classes = num_classes
self.depths = depths
self.output_avg = output_avg
# patch_embed
self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
embed_dim=embed_dims[0])
self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
embed_dim=embed_dims[1])
self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
embed_dim=embed_dims[2])
self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
embed_dim=embed_dims[3])
# transformer encoder
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
cur = 0
self.block1 = nn.ModuleList([Block(
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[0])
for i in range(depths[0])])
self.norm1 = norm_layer(embed_dims[0])
cur += depths[0]
self.block2 = nn.ModuleList([Block(
dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[1])
for i in range(depths[1])])
self.norm2 = norm_layer(embed_dims[1])
cur += depths[1]
self.block3 = nn.ModuleList([Block(
dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[2])
for i in range(depths[2])])
self.norm3 = norm_layer(embed_dims[2])
cur += depths[2]
self.block4 = nn.ModuleList([Block(
dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
sr_ratio=sr_ratios[3])
for i in range(depths[3])])
self.norm4 = norm_layer(embed_dims[3])
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def reset_drop_path(self, drop_path_rate):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
cur = 0
for i in range(self.depths[0]):
self.block1[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[0]
for i in range(self.depths[1]):
self.block2[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[1]
for i in range(self.depths[2]):
self.block3[i].drop_path.drop_prob = dpr[cur + i]
cur += self.depths[2]
for i in range(self.depths[3]):
self.block4[i].drop_path.drop_prob = dpr[cur + i]
def freeze_patch_emb(self):
self.patch_embed1.requires_grad = False
def forward_features(self, x):
B = x.shape[0]
outs = []
# stage 1
x, H, W = self.patch_embed1(x)
for i, blk in enumerate(self.block1):
x = blk(x, H, W)
x = self.norm1(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 2
x, H, W = self.patch_embed2(x)
for i, blk in enumerate(self.block2):
x = blk(x, H, W)
x = self.norm2(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 3
x, H, W = self.patch_embed3(x)
for i, blk in enumerate(self.block3):
x = blk(x, H, W)
x = self.norm3(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
# stage 4
x, H, W = self.patch_embed4(x)
for i, blk in enumerate(self.block4):
x = blk(x, H, W)
x = self.norm4(x)
if not self.output_avg:
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
outs.append(x)
return outs
def forward(self, x):
x = self.forward_features(x)
if self.output_avg:
x = x[3].mean(dim=1)
return x
class DWConv(nn.Module):
def __init__(self, dim=768):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x
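# Example sketch (for illustration only; not used elsewhere in this file): DWConv
# mixes tokens spatially by reshaping (B, N, C) to (B, C, H, W), applying a 3x3
# depthwise convolution, and flattening back, so the token layout is preserved.
# Sizes below are arbitrary example values.
def _example_dwconv():
    H = W = 14
    dwconv = DWConv(dim=32)
    tokens = torch.randn(2, H * W, 32)
    out = dwconv(tokens, H, W)
    assert out.shape == tokens.shape   # (2, 196, 32)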
class mit_b0(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b0, self).__init__(
patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b1(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b1, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b2(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b2, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b3(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b3, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b3_avg(MixVisionTransformer):
def __init__(self, drop_path_rate=0.1, **kwargs):
super(mit_b3_avg, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True)
class mit_b4(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b4, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b5(MixVisionTransformer):
def __init__(self, **kwargs):
super(mit_b5, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=0.1)
class mit_b5_avg(MixVisionTransformer):
def __init__(self, drop_path_rate=0.1, **kwargs):
super(mit_b5_avg, self).__init__(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True)
# Copyright (c) 2021 Microsoft
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Swin Transformer
# --------------------------------------------------------
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import numpy as np
from math import sqrt
from megatron import get_args
from functools import partial
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None,
out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both shifted and non-shifted windows.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.H = input_resolution[0]
self.W = input_resolution[1]
self.attn_mask_dict = {}
def create_attn_mask(self, H, W):
# calculate attention mask for SW-MSA
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1)) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x):
B, L, C = x.shape
H = int(sqrt(L))
W = H
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift; the shifted-window attention mask is built lazily and cached per resolution
if self.shift_size > 0:
    shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    if H in self.attn_mask_dict:
        attn_mask = self.attn_mask_dict[H]
    else:
        attn_mask = self.create_attn_mask(H, W).to(x.device)
        self.attn_mask_dict[H] = attn_mask
else:
    shifted_x = x
    attn_mask = None
# partition windows
x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
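A quick shape check of the merge above (sizes chosen arbitrarily): the four stride-2 slices turn (B, H*W, C) into (B, H/2*W/2, 4C), which the linear reduction maps down to 2C.

import torch
import torch.nn as nn

merge = PatchMerging(input_resolution=(56, 56), dim=96, norm_layer=nn.LayerNorm)
tokens = torch.randn(2, 56 * 56, 96)        # (B, H*W, C)
print(merge(tokens).shape)                  # torch.Size([2, 784, 192]) -> (B, H/2*W/2, 2*C)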
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
x_b4_ds = x
if self.downsample is not None:
x = self.downsample(x)
return x_b4_ds, x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
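A minimal usage sketch of the embedding above, assuming the helpers imported earlier in this file (to_2tuple, nn): a 224x224 image becomes 56x56 = 3136 tokens of width 96.

import torch
import torch.nn as nn

embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=nn.LayerNorm)
tokens = embed(torch.randn(1, 3, 224, 224))
print(tokens.shape)                         # torch.Size([1, 3136, 96])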
class SwinTransformer(nn.Module):
r""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3,
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
norm_layer=partial(nn.LayerNorm, eps=1e-6), ape=False, patch_norm=True,
use_checkpoint=False, output_avg=False, **kwargs):
super().__init__()
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
self.img_size = to_2tuple(img_size)
self.patch_size = to_2tuple(patch_size)
self.output_avg = output_avg
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layer)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def forward(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
h = self.img_size[0] // self.patch_size[0]
w = self.img_size[1] // self.patch_size[1]
outs = []
for i, layer in enumerate(self.layers):
px, x = layer(x)
b, n, c = px.shape
if i != len(self.layers) - 1 or not self.output_avg:
px = px.permute(0, 2, 1).contiguous()
px = px.reshape(b, c, h, w)
            # PatchMerging halves the spatial resolution between stages, so halve h, w for the next stage's reshape.
h, w = h//2, w//2
outs.append(px)
if self.output_avg:
return outs[-1].mean(dim=1)
return outs
def flops(self):
flops = 0
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        # This backbone has no classification head, so no head FLOPs are added.
return flops
def get_swin(drop_path_rate=0.3, output_avg=False):
args = get_args()
window_size = 7
embed_dim = 128
depths = [2, 2, 18, 2]
num_heads = [4, 8, 16, 32]
swin = SwinTransformer(
img_size=(args.img_h, args.img_w,),
in_chans=3,
patch_size=args.patch_dim,
embed_dim=embed_dim,
depths=depths,
num_heads=num_heads,
window_size=window_size,
drop_path_rate=drop_path_rate,
output_avg=output_avg,
)
return swin
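A hedged usage sketch that bypasses get_args and builds the backbone directly with Swin-T hyperparameters, assuming the helper layers defined earlier in this file (WindowAttention, Mlp, DropPath); with output_avg=False the forward returns the per-stage feature pyramid reshaped to (B, C, H, W).

import torch

model = SwinTransformer(img_size=224, patch_size=4, embed_dim=96,
                        depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                        window_size=7, drop_path_rate=0.2)
model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)
# Expected pyramid: (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)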
import warnings
import torch
import torch.nn.functional as F
def resize(input,
size=None,
scale_factor=None,
mode='nearest',
align_corners=None,
warning=True):
if warning:
if size is not None and align_corners:
input_h, input_w = tuple(int(x) for x in input.shape[2:])
output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > input_w:
if ((output_h > 1 and output_w > 1 and input_h > 1
and input_w > 1) and (output_h - 1) % (input_h - 1)
and (output_w - 1) % (input_w - 1)):
warnings.warn(
f'When align_corners={align_corners}, '
                    'the output would be more aligned if '
f'input size {(input_h, input_w)} is `x+1` and '
f'out size {(output_h, output_w)} is `nx+1`')
if isinstance(size, torch.Size):
size = tuple(int(x) for x in size)
return F.interpolate(input, size, scale_factor, mode, align_corners)
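A typical use of resize is upsampling the coarsest feature map back to the input resolution, e.g. in a segmentation head (sizes here are illustrative):

import torch

feat = torch.randn(1, 768, 7, 7)
up = resize(feat, size=(224, 224), mode='bilinear', align_corners=False)
print(up.shape)                             # torch.Size([1, 768, 224, 224])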
...@@ -21,7 +21,6 @@ import torch ...@@ -21,7 +21,6 @@ import torch
import apex import apex
import torch.nn.functional as F import torch.nn.functional as F
from megatron import get_args from megatron import get_args
from megatron.model import LayerNorm
from megatron.model.transformer import ParallelTransformer from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import ( from megatron.model.utils import (
get_linear_layer, get_linear_layer,
...@@ -147,7 +146,9 @@ class VitBackbone(MegatronModule): ...@@ -147,7 +146,9 @@ class VitBackbone(MegatronModule):
pre_process=True, pre_process=True,
post_process=True, post_process=True,
class_token=True, class_token=True,
single_token_output=False): single_token_output=False,
post_layer_norm=True,
drop_path_rate=0.0):
super(VitBackbone, self).__init__(share_word_embeddings=False) super(VitBackbone, self).__init__(share_word_embeddings=False)
args = get_args() args = get_args()
...@@ -164,12 +165,14 @@ class VitBackbone(MegatronModule): ...@@ -164,12 +165,14 @@ class VitBackbone(MegatronModule):
self.pre_process = pre_process self.pre_process = pre_process
self.post_process = post_process self.post_process = post_process
self.class_token = class_token self.class_token = class_token
self.post_layer_norm = post_layer_norm
self.hidden_size = args.hidden_size self.hidden_size = args.hidden_size
self.patch_dim = args.patch_dim self.patch_dim = args.patch_dim
self.img_h = args.img_h self.img_h = args.img_h
self.img_w = args.img_w self.img_w = args.img_w
self.micro_batch_size = args.micro_batch_size self.micro_batch_size = args.micro_batch_size
self.single_token_output = single_token_output self.single_token_output = single_token_output
self.drop_path_rate = drop_path_rate
assert self.img_h % self.patch_dim == 0 assert self.img_h % self.patch_dim == 0
assert self.img_w % self.patch_dim == 0 assert self.img_w % self.patch_dim == 0
...@@ -216,6 +219,8 @@ class VitBackbone(MegatronModule): ...@@ -216,6 +219,8 @@ class VitBackbone(MegatronModule):
self.scaled_init_method, self.scaled_init_method,
pre_process=self.pre_process, pre_process=self.pre_process,
post_process=self.post_process, post_process=self.post_process,
post_layer_norm=self.post_layer_norm,
drop_path_rate=self.drop_path_rate
) )
def set_input_tensor(self, input_tensor): def set_input_tensor(self, input_tensor):
......
...@@ -49,17 +49,21 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi ...@@ -49,17 +49,21 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
from .initialize import initialize_model_parallel from .initialize import initialize_model_parallel
from .initialize import model_parallel_is_initialized from .initialize import model_parallel_is_initialized
from .layers import LinearWithGradAccumulationAndAsyncCommunication
from .layers import ColumnParallelLinear from .layers import ColumnParallelLinear
from .layers import RowParallelLinear from .layers import RowParallelLinear
from .layers import VocabParallelEmbedding from .layers import VocabParallelEmbedding
from .layers import (set_tensor_model_parallel_attributes, from .layers import (set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes, set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes) copy_tensor_model_parallel_attributes)
from .mappings import copy_to_tensor_model_parallel_region from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region from .mappings import scatter_to_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region from .mappings import gather_from_tensor_model_parallel_region
from .mappings import scatter_to_sequence_parallel_region
from .mappings import gather_from_sequence_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
from .random import checkpoint from .random import checkpoint
from .random import get_cuda_rng_tracker from .random import get_cuda_rng_tracker
......
...@@ -54,6 +54,12 @@ _POSITION_EMBEDDING_GLOBAL_RANKS = None ...@@ -54,6 +54,12 @@ _POSITION_EMBEDDING_GLOBAL_RANKS = None
# rank when broadcasting from the first or last pipeline stage. # rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None _PIPELINE_GLOBAL_RANKS = None
# A list of global ranks for each data parallel group to ease calculation of the source
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None
def is_unitialized(): def is_unitialized():
"""Useful for code segments that may be accessed with or without mpu initialization""" """Useful for code segments that may be accessed with or without mpu initialization"""
return _DATA_PARALLEL_GROUP is None return _DATA_PARALLEL_GROUP is None
...@@ -124,6 +130,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, ...@@ -124,6 +130,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
# Build the data-parallel groups. # Build the data-parallel groups.
global _DATA_PARALLEL_GROUP global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_GLOBAL_RANKS
assert _DATA_PARALLEL_GROUP is None, \ assert _DATA_PARALLEL_GROUP is None, \
'data parallel group is already initialized' 'data parallel group is already initialized'
all_data_parallel_group_ranks = [] all_data_parallel_group_ranks = []
...@@ -137,6 +144,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1, ...@@ -137,6 +144,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
group = torch.distributed.new_group(ranks) group = torch.distributed.new_group(ranks)
if rank in ranks: if rank in ranks:
_DATA_PARALLEL_GROUP = group _DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GLOBAL_RANKS = ranks
# Build the model-parallel groups. # Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP global _MODEL_PARALLEL_GROUP
...@@ -478,11 +486,10 @@ def get_tensor_model_parallel_src_rank(): ...@@ -478,11 +486,10 @@ def get_tensor_model_parallel_src_rank():
def get_data_parallel_src_rank(): def get_data_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank """Calculate the global rank corresponding to the first local rank
in the tensor model parallel group.""" in the data parallel group."""
global_rank = torch.distributed.get_rank() assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
data_parallel_size = get_data_parallel_world_size() "Data parallel group is not initialized"
num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size return _DATA_PARALLEL_GLOBAL_RANKS[0]
return global_rank % num_data_parallel_groups
def get_pipeline_model_parallel_first_rank(): def get_pipeline_model_parallel_first_rank():
......
...@@ -30,20 +30,21 @@ from .initialize import get_tensor_model_parallel_world_size ...@@ -30,20 +30,21 @@ from .initialize import get_tensor_model_parallel_world_size
from .initialize import get_tensor_model_parallel_group from .initialize import get_tensor_model_parallel_group
from .mappings import copy_to_tensor_model_parallel_region from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region from .mappings import gather_from_tensor_model_parallel_region
from .mappings import gather_from_sequence_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region from .mappings import scatter_to_tensor_model_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
from .random import get_cuda_rng_tracker from .random import get_cuda_rng_tracker
from .utils import divide from .utils import divide
from .utils import split_tensor_along_last_dim from .utils import split_tensor_along_last_dim
from .utils import VocabUtility from .utils import VocabUtility
from megatron import get_args from megatron import get_args, get_global_memory_buffer
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1, 'partition_dim': -1,
'partition_stride': 1} 'partition_stride': 1}
def param_is_not_tensor_parallel_duplicate(param): def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, 'tensor_model_parallel') and return (hasattr(param, 'tensor_model_parallel') and
param.tensor_model_parallel) or ( param.tensor_model_parallel) or (
...@@ -199,16 +200,37 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -199,16 +200,37 @@ class VocabParallelEmbedding(torch.nn.Module):
return output return output
class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function): class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
""" """
Column-parallel linear layer execution with asynchronous all-reduce Linear layer execution with asynchronous communication and gradient accumulation
execution in backprop. fusion in backprop.
""" """
@staticmethod @staticmethod
def forward(ctx, input, weight, bias): def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
async_grad_allreduce, sequence_parallel):
ctx.save_for_backward(input, weight) ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None ctx.use_bias = bias is not None
output = torch.matmul(input, weight.t()) ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
ctx.sequence_parallel = sequence_parallel
if sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = \
get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
torch.distributed._all_gather_base(
all_gather_buffer,
input,
group=get_tensor_model_parallel_group())
total_input = all_gather_buffer
else:
total_input = input
output = torch.matmul(total_input, weight.t())
if bias is not None: if bias is not None:
output = output + bias output = output + bias
return output return output
...@@ -217,17 +239,75 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function): ...@@ -217,17 +239,75 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
def backward(ctx, grad_output): def backward(ctx, grad_output):
input, weight = ctx.saved_tensors input, weight = ctx.saved_tensors
use_bias = ctx.use_bias use_bias = ctx.use_bias
if ctx.sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = \
get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
handle = torch.distributed._all_gather_base(
all_gather_buffer,
input,
group=get_tensor_model_parallel_group(), async_op=True)
            # Delay the start of input gradient computation shortly (3us) to have
# gather scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
total_input = all_gather_buffer
else:
total_input = input
grad_input = grad_output.matmul(weight) grad_input = grad_output.matmul(weight)
# Asyncronous all-reduce
handle = torch.distributed.all_reduce( if ctx.sequence_parallel:
grad_input, group=get_tensor_model_parallel_group(), async_op=True) handle.wait()
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated # Convert the tensor shapes to 2D for execution compatibility
_ = torch.empty(1, device=grad_output.device) + 1 grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
grad_weight = grad_output.t().matmul(input) grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
total_input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.sequence_parallel:
assert not ctx.async_grad_allreduce
dim_size = list(input.size())
sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# reduce_scatter
handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
group=get_tensor_model_parallel_group(),
async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
if ctx.gradient_accumulation_fusion:
import fused_dense_cuda
fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None grad_bias = grad_output.sum(dim=0) if use_bias else None
handle.wait()
return grad_input, grad_weight, grad_bias if ctx.sequence_parallel:
handle.wait()
return sub_grad_input, grad_weight, grad_bias, None, None, None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None
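For reference, a minimal sketch of what the gradient-accumulation fusion branch does, spelled out with plain torch ops instead of the fused_dense_cuda kernel (buffer sizes are illustrative):

import torch

weight = torch.nn.Parameter(torch.randn(1024, 512))    # (output_size, input_size)
weight.main_grad = torch.zeros(1024, 512)               # persistent fp32 grad buffer

total_input = torch.randn(2048, 512)                    # (s*b, input_size) after the 3D->2D view
grad_output = torch.randn(2048, 1024)                   # (s*b, output_size)

# The fused kernel accumulates the weight-gradient GEMM directly into main_grad
# instead of materializing and returning grad_weight:
weight.main_grad += grad_output.t().matmul(total_input)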
class ColumnParallelLinear(torch.nn.Module): class ColumnParallelLinear(torch.nn.Module):
...@@ -240,7 +320,7 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -240,7 +320,7 @@ class ColumnParallelLinear(torch.nn.Module):
input_size: first dimension of matrix A. input_size: first dimension of matrix A.
output_size: second dimension of matrix A. output_size: second dimension of matrix A.
bias: If true, add bias bias: If true, add bias
gather_output: If true, call all-gather on output and make Y avaiable gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i which is Y_i = XA_i
init_method: method to initialize weights. Note that bias is always set init_method: method to initialize weights. Note that bias is always set
...@@ -305,31 +385,30 @@ class ColumnParallelLinear(torch.nn.Module): ...@@ -305,31 +385,30 @@ class ColumnParallelLinear(torch.nn.Module):
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.async_tensor_model_parallel_allreduce = ( self.async_tensor_model_parallel_allreduce = (
not args.no_async_tensor_model_parallel_allreduce and args.async_tensor_model_parallel_allreduce and
world_size > 1) world_size > 1)
self.sequence_parallel = (
args.sequence_parallel and
world_size > 1)
assert not self.async_tensor_model_parallel_allreduce or \
not self.sequence_parallel
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
def forward(self, input_): def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
if self.async_tensor_model_parallel_allreduce: if self.async_tensor_model_parallel_allreduce or \
input_shape = input_.shape self.sequence_parallel:
input_ = input_.view(input_shape[0] * input_shape[1],input_shape[2]) input_parallel = input_
# Maxtrix multiply with asynchronouse all-reduce execution
output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
input_, self.weight, bias)
output_parallel = output_parallel.view(
input_shape[0], input_shape[1], output_parallel.shape[1])
else: else:
# Set up backprop all-reduce.
input_parallel = copy_to_tensor_model_parallel_region(input_) input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply.
# Matrix multiply. output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
output_parallel = F.linear(input_parallel, self.weight, bias) input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
self.async_tensor_model_parallel_allreduce, self.sequence_parallel)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
assert not self.sequence_parallel
output = gather_from_tensor_model_parallel_region(output_parallel) output = gather_from_tensor_model_parallel_region(output_parallel)
else: else:
output = output_parallel output = output_parallel
...@@ -410,11 +489,15 @@ class RowParallelLinear(torch.nn.Module): ...@@ -410,11 +489,15 @@ class RowParallelLinear(torch.nn.Module):
self.bias = Parameter(torch.empty( self.bias = Parameter(torch.empty(
self.output_size, device=torch.cuda.current_device(), self.output_size, device=torch.cuda.current_device(),
dtype=args.params_dtype)) dtype=args.params_dtype))
setattr(self.bias, 'sequence_parallel', args.sequence_parallel)
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
self.bias.zero_() self.bias.zero_()
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.sequence_parallel = args.sequence_parallel
self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
...@@ -423,11 +506,17 @@ class RowParallelLinear(torch.nn.Module): ...@@ -423,11 +506,17 @@ class RowParallelLinear(torch.nn.Module):
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
else: else:
assert not self.sequence_parallel
input_parallel = scatter_to_tensor_model_parallel_region(input_) input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply. # Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight) output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
input_parallel, self.weight, None,
self.gradient_accumulation_fusion, None, None)
# All-reduce across all the partitions. # All-reduce across all the partitions.
output_ = reduce_from_tensor_model_parallel_region(output_parallel) if self.sequence_parallel:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add: if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_ output = output_ + self.bias if self.bias is not None else output_
output_bias = None output_bias = None
......
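The column-parallel / row-parallel pairing above composes with a single reduction; a single-process sketch of the underlying algebra (shard shapes are illustrative and ignore Megatron's transposed weight storage):

import torch

tp = 4
X = torch.randn(16, 512, dtype=torch.float64)
A = torch.randn(512, 2048, dtype=torch.float64)    # column-parallel weight, split along dim 1
B = torch.randn(2048, 512, dtype=torch.float64)    # row-parallel weight, split along dim 0

partials = [(X @ A_i) @ B_i
            for A_i, B_i in zip(A.chunk(tp, dim=1), B.chunk(tp, dim=0))]
out = torch.stack(partials).sum(0)                 # what the all-reduce / reduce-scatter produces
assert torch.allclose(out, X @ A @ B)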
...@@ -32,13 +32,13 @@ def _reduce(input_): ...@@ -32,13 +32,13 @@ def _reduce(input_):
return input_ return input_
def _split(input_): def _split_along_last_dim(input_):
"""Split the tensor along its last dimension and keep the """Split the tensor along its last dimension and keep the
corresponding slice.""" corresponding slice."""
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if world_size==1: if world_size == 1:
return input_ return input_
# Split along last dimension. # Split along last dimension.
...@@ -51,12 +51,34 @@ def _split(input_): ...@@ -51,12 +51,34 @@ def _split(input_):
return output return output
def _gather(input_): def _split_along_first_dim(input_):
"""Split the tensor along its first dimension and keep the
corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Split along first dimension.
dim_size = input_.size()[0]
assert dim_size % world_size == 0, \
"First dimension of the tensor should be divisible by tensor parallel size"
local_dim_size = dim_size // world_size
rank = get_tensor_model_parallel_rank()
dim_offset = rank * local_dim_size
output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
return output
def _gather_along_last_dim(input_):
"""Gather tensors and concatinate along the last dimension.""" """Gather tensors and concatinate along the last dimension."""
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if world_size==1: if world_size == 1:
return input_ return input_
# Size and dimension. # Size and dimension.
...@@ -73,6 +95,44 @@ def _gather(input_): ...@@ -73,6 +95,44 @@ def _gather(input_):
return output return output
def _gather_along_first_dim(input_):
"""Gather tensors and concatinate along the first dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device())
torch.distributed._all_gather_base(output, input_.contiguous(),
group=get_tensor_model_parallel_group())
return output
def _reduce_scatter_along_first_dim(input_):
"""Reduce-scatter the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
assert dim_size[0] % world_size == 0, \
"First dimension of the tensor should be divisible by tensor parallel size"
dim_size[0] = dim_size[0] // world_size
output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device())
torch.distributed._reduce_scatter_base(output, input_.contiguous(),
group=get_tensor_model_parallel_group())
return output
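A single-process emulation of the two collectives used above, just to pin down the first-dimension bookkeeping (shapes are illustrative):

import torch

tp = 4
# _all_gather_base: each rank holds a (s/tp, b, h) chunk; every rank receives the dim-0 concat.
local_chunks = [torch.randn(128, 8, 64) for _ in range(tp)]
gathered = torch.cat(local_chunks, dim=0)          # (512, 8, 64)

# _reduce_scatter_base: each rank holds a full (s, b, h) tensor; the tensors are summed
# elementwise and rank r keeps only the r-th (s/tp, b, h) slice of the sum.
full_inputs = [torch.randn(512, 8, 64) for _ in range(tp)]
per_rank = torch.stack(full_inputs).sum(0).chunk(tp, dim=0)
print(gathered.shape, per_rank[0].shape)           # torch.Size([512, 8, 64]) torch.Size([128, 8, 64])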
class _CopyToModelParallelRegion(torch.autograd.Function): class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region.""" """Pass the input to the model parallel region."""
...@@ -110,15 +170,15 @@ class _ScatterToModelParallelRegion(torch.autograd.Function): ...@@ -110,15 +170,15 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
@staticmethod @staticmethod
def symbolic(graph, input_): def symbolic(graph, input_):
return _split(input_) return _split_along_last_dim(input_)
@staticmethod @staticmethod
def forward(ctx, input_): def forward(ctx, input_):
return _split(input_) return _split_along_last_dim(input_)
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
return _gather(grad_output) return _gather_along_last_dim(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function): class _GatherFromModelParallelRegion(torch.autograd.Function):
...@@ -126,15 +186,73 @@ class _GatherFromModelParallelRegion(torch.autograd.Function): ...@@ -126,15 +186,73 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):
@staticmethod @staticmethod
def symbolic(graph, input_): def symbolic(graph, input_):
return _gather(input_) return _gather_along_last_dim(input_)
@staticmethod @staticmethod
def forward(ctx, input_): def forward(ctx, input_):
return _gather(input_) return _gather_along_last_dim(input_)
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx, grad_output):
return _split(grad_output) return _split_along_last_dim(grad_output)
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_, tensor_parallel_output_grad=True):
return _gather_along_first_dim(input_)
@staticmethod
def forward(ctx, input_, tensor_parallel_output_grad=True):
ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
return _gather_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
        # If the computation graph after the gather operation is
        # in tensor parallel mode, output gradients need to be
        # reduce-scattered, whereas if the computation is duplicated
        # across ranks, they only need to be split along the first dimension.
if tensor_parallel_output_grad:
return _reduce_scatter_along_first_dim(grad_output), None
else:
return _split_along_first_dim(grad_output), None
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
"""Reduce scatter the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
# ----------------- # -----------------
...@@ -155,3 +273,16 @@ def scatter_to_tensor_model_parallel_region(input_): ...@@ -155,3 +273,16 @@ def scatter_to_tensor_model_parallel_region(input_):
def gather_from_tensor_model_parallel_region(input_): def gather_from_tensor_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_) return _GatherFromModelParallelRegion.apply(input_)
def scatter_to_sequence_parallel_region(input_):
return _ScatterToSequenceParallelRegion.apply(input_)
def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)
def reduce_scatter_to_sequence_parallel_region(input_):
return _ReduceScatterToSequenceParallelRegion.apply(input_)