Commit b14e47f4 authored by zhuwenwen

Merge branch 'main' of https://github.com/hpcaitech/FastFold

parents 490cb6f5 05681304
from .core import init_dap
from .comm import (_reduce, _split, _gather, copy, scatter, reduce, gather, col_to_row, row_to_col)
__all__ = [
'init_dap', '_reduce', '_split', '_gather', 'copy', 'scatter', 'reduce', 'gather', 'col_to_row',
'row_to_col'
]
from typing import Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from .core import ensure_divisibility
def divide(numerator, denominator):
ensure_divisibility(numerator, denominator)
return numerator // denominator
def _reduce(tensor: Tensor) -> Tensor:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor
dist.all_reduce(tensor,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
return tensor
def _split(tensor: Tensor, dim: int = -1) -> Tensor:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor
split_size = divide(tensor.shape[dim], gpc.get_world_size(ParallelMode.TENSOR))
tensor_list = torch.split(tensor, split_size, dim=dim)
output = tensor_list[gpc.get_local_rank(ParallelMode.TENSOR)].contiguous()
return output
def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor
if dim == 1 and list(tensor.shape)[0] == 1:
output_shape = list(tensor.shape)
output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
tensor_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
dist.all_gather(list(tensor_list),
tensor,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
else:
tensor_list = [
torch.empty_like(tensor) for _ in range(gpc.get_world_size(ParallelMode.TENSOR))
]
dist.all_gather(tensor_list,
tensor,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
output = torch.cat(tensor_list, dim=dim)
return output
def copy(input: Tensor) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Copy.apply(input)
return input
class Copy(torch.autograd.Function):
@staticmethod
def forward(ctx: "Copy", input: Tensor) -> Tensor:
return input
@staticmethod
def backward(ctx: "Copy", grad_output: Tensor) -> Tensor:
return _reduce(grad_output)
def scatter(input: Tensor, dim: int = -1) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Scatter.apply(input, dim)
else:
input = _split(input, dim=dim)
return input
class Scatter(torch.autograd.Function):
@staticmethod
def forward(ctx: "Scatter", input: Tensor, dim: int = -1) -> Tensor:
ctx.save_for_backward(torch.tensor([dim]))
return _split(input, dim=dim)
@staticmethod
def backward(ctx: "Scatter", grad_output: Tensor) -> Tuple[Tensor]:
dim, = ctx.saved_tensors
return _gather(grad_output, dim=int(dim)), None
def reduce(input: Tensor) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Reduce.apply(input)
else:
input = _reduce(input)
return input
class Reduce(torch.autograd.Function):
@staticmethod
def forward(ctx: "Reduce", input: Tensor) -> Tensor:
return _reduce(input)
@staticmethod
def backward(ctx: "Reduce", grad_output: Tensor) -> Tensor:
return grad_output
def gather(input: Tensor, dim: int = -1) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Gather.apply(input, dim)
else:
input = _gather(input, dim=dim)
return input
class Gather(torch.autograd.Function):
@staticmethod
def forward(ctx: "Gather", input: Tensor, dim: int = -1) -> Tensor:
ctx.save_for_backward(torch.tensor([dim]))
return _gather(input, dim=dim)
@staticmethod
def backward(ctx: "Gather", grad_output: Tensor) -> Tuple[Tensor]:
dim, = ctx.saved_tensors
return _split(grad_output, dim=int(dim)), None
# Split the local tensor into chunks along in_dim, exchange the chunks across the tensor
# parallel group with all_to_all, and concatenate the received chunks along out_dim.
def _all_to_all(tensor: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor
split_size = divide(tensor.shape[in_dim], gpc.get_world_size(ParallelMode.TENSOR))
input_tensor_list = torch.split(tensor, split_size, dim=in_dim)
input_tensor_list = [tensor_.contiguous() for tensor_ in input_tensor_list]
if out_dim == 1:
output_shape = list(input_tensor_list[0].shape)
output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
output_tensor_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
dist.all_to_all(list(output_tensor_list),
input_tensor_list,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
else:
output_tensor_list = [torch.ones_like(tensor_) for tensor_ in input_tensor_list]
dist.all_to_all(output_tensor_list,
input_tensor_list,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
output = torch.cat(output_tensor_list, dim=out_dim)
return output
def col_to_row(input_: Tensor) -> Tensor:
if torch.is_grad_enabled() and input_.requires_grad:
input_ = All_to_All.apply(input_, 1, 2)
else:
input_ = _all_to_all(input_, in_dim=1, out_dim=2)
return input_
def row_to_col(input_: Tensor) -> Tensor:
if torch.is_grad_enabled() and input_.requires_grad:
input_ = All_to_All.apply(input_, 2, 1)
else:
input_ = _all_to_all(input_, in_dim=2, out_dim=1)
return input_
class All_to_All(torch.autograd.Function):
@staticmethod
def forward(ctx: "All_to_All", input_: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
ctx.save_for_backward(torch.tensor([in_dim, out_dim]))
return _all_to_all(input_, in_dim=in_dim, out_dim=out_dim)
@staticmethod
def backward(ctx: "All_to_All", grad_output: Tensor) -> Tuple[Tensor]:
saved_tensors = ctx.saved_tensors[0]
return _all_to_all(grad_output, in_dim=int(saved_tensors[1]),
out_dim=int(saved_tensors[0])), None, None
from typing import Tuple
from einops import rearrange
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from .comm import _split, divide
def broadcast_sync(src: int, tensor: Tensor, host: bool = False) -> Tensor:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return 0
if host:
dist.broadcast(tensor,
src=src,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
return 0
else:
output = torch.empty(list(tensor.shape), dtype=tensor.dtype, device=tensor.device)
dist.broadcast(output,
src=src,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
return output
def broadcast_async(src: int, tensor: Tensor, host: bool = False) -> Tensor:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return 0
if host:
work = dist.broadcast(tensor,
src=src,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=True)
return work
else:
work = dist.broadcast(tensor,
src=src,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=True)
return work
def broadcast_async_opp(work) -> Tensor:
work.wait()
return 0
def get_rank():
return gpc.get_global_rank()
def get_world_size():
return gpc.get_world_size(ParallelMode.TENSOR)
# Launch an asynchronous all-gather into a buffer preallocated along dim 1 and return the
# buffer together with the async work handle (the handle is None on a single device).
def _gather_async(tensor: Tensor, dim: int = -1) -> Tensor:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor, None
output_shape = list(tensor.shape)
output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
tensor_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
work = dist.all_gather(list(tensor_list),
tensor,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=True)
return output, work
def gather_async(input: Tensor, dim: int = -1) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input, work = GatherAsync.apply(input, dim)
else:
input, work = _gather_async(input, dim=dim)
return input, work
def gather_async_opp(output: Tensor, work, dim: int = -1) -> Tensor:
if work:
work.wait()
if dim == 2:
output = GatherAsyncOpp.apply(output)
return output
class GatherAsyncOpp(torch.autograd.Function):
@staticmethod
def forward(ctx: "GatherAsyncOpp", input: Tensor) -> Tensor:
mp_size = gpc.get_world_size(ParallelMode.TENSOR)
output = rearrange(input, 'n (x h) w c -> n h (x w) c', x=mp_size)
return output
@staticmethod
def backward(ctx: "GatherAsyncOpp", grad_output: Tensor) -> Tuple[Tensor]:
mp_size = gpc.get_world_size(ParallelMode.TENSOR)
n, h, w, c = grad_output.shape
return grad_output.resize_(n, h * mp_size, int(w / mp_size), c)
class GatherAsync(torch.autograd.Function):
@staticmethod
def forward(ctx: "GatherAsync", input: Tensor, dim: int = -1) -> Tensor:
ctx.dim = dim
return _gather_async(input, dim=dim)
@staticmethod
def backward(ctx: "GatherAsync", grad_output: Tensor, grad_work=None) -> Tuple[Tensor]:
if ctx.dim == 2:
mp_size = gpc.get_world_size(ParallelMode.TENSOR)
n, h, w, c = grad_output.shape
grad_output.resize_(n, int(h / mp_size), w * mp_size, c)
return _split(grad_output, dim=ctx.dim), None
def _all_to_all_async(tensor: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor, None
split_size = divide(tensor.shape[in_dim], gpc.get_world_size(ParallelMode.TENSOR))
input_tensor_list = torch.split(tensor, split_size, dim=in_dim)
input_tensor_list = [tensor_.contiguous() for tensor_ in input_tensor_list]
output_shape = list(input_tensor_list[0].shape)
output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
output_tensor_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
work = dist.all_to_all(list(output_tensor_list),
input_tensor_list,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=True)
return output, work
WORLD_WORK_ALL2ALL = None
class All_to_All_Async(torch.autograd.Function):
@staticmethod
def forward(ctx: "All_to_All_Async",
input_: Tensor,
in_dim: int = -1,
out_dim: int = -1) -> Tensor:
ctx.in_dim = in_dim
ctx.out_dim = out_dim
return _all_to_all_async(input_, in_dim=in_dim, out_dim=out_dim)
@staticmethod
def backward(ctx: "All_to_All_Async", grad_output: Tensor, grad_work=None) -> Tuple[Tensor]:
global WORLD_WORK_ALL2ALL
if WORLD_WORK_ALL2ALL:
WORLD_WORK_ALL2ALL.wait()
WORLD_WORK_ALL2ALL = None
if ctx.in_dim == 2:
mp_size = gpc.get_world_size(ParallelMode.TENSOR)
grad_output = rearrange(grad_output, 'n (x h) w c -> n h (x w) c', x=mp_size)
return grad_output, None, None
class All_to_All_Async_Opp(torch.autograd.Function):
@staticmethod
def forward(ctx: "All_to_All_Async_Opp",
output: Tensor,
work,
in_dim: int = -1,
out_dim: int = -1) -> Tensor:
ctx.in_dim = in_dim
ctx.out_dim = out_dim
if work:
work.wait()
if out_dim == 2:
mp_size = gpc.get_world_size(ParallelMode.TENSOR)
output = rearrange(output, 'n (x h) w c -> n h (x w) c', x=mp_size)
return output
@staticmethod
def backward(ctx: "All_to_All_Async_Opp", grad_output: Tensor) -> Tuple[Tensor]:
global WORLD_WORK_ALL2ALL
d_tensor, WORLD_WORK_ALL2ALL = _all_to_all_async(grad_output,
in_dim=ctx.out_dim,
out_dim=ctx.in_dim)
return d_tensor, None, None, None
import os
import torch
import colossalai
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator)
def set_missing_distributed_environ(key, value):
if key not in os.environ:
os.environ[str(key)] = str(value)
def init_dap(tensor_model_parallel_size_=None):
colossalai.logging.disable_existing_loggers()
if tensor_model_parallel_size_ is None:
if 'WORLD_SIZE' in os.environ:
tensor_model_parallel_size_ = int(os.environ['WORLD_SIZE'])
else:
tensor_model_parallel_size_ = 1
if torch.distributed.is_initialized():
_logger = colossalai.logging.get_dist_logger()
_logger.error(
"use fastfold.distributed.init_dap instead of torch.distributed.init_process_group!")
exit(-1)
# set distributed environ for single device launch
set_missing_distributed_environ('WORLD_SIZE', 1)
set_missing_distributed_environ('RANK', 0)
set_missing_distributed_environ('LOCAL_RANK', 0)
set_missing_distributed_environ('MASTER_ADDR', "localhost")
set_missing_distributed_environ('MASTER_PORT', 18417)
colossalai.launch_from_torch(
config={"parallel": dict(tensor=dict(size=tensor_model_parallel_size_))})
ENABLE_HABANA = False
ENABLE_HMP = False
def enable_habana():
global ENABLE_HABANA
ENABLE_HABANA = True
global ENABLE_LAZY_MODE
ENABLE_LAZY_MODE = True
import habana_frameworks.torch.core
def is_habana():
global ENABLE_HABANA
return ENABLE_HABANA
def enable_hmp():
global ENABLE_HMP
ENABLE_HMP = True
def is_hmp():
global ENABLE_HMP
return ENABLE_HMP
from .comm import (All_to_All, _gather, _reduce, _split, col_to_row, copy,
gather, reduce, row_to_col, scatter)
from .core import init_dist, get_data_parallel_world_size
__all__ = [
'init_dist', 'get_data_parallel_world_size', '_reduce', '_split', '_gather', 'copy', 'scatter', 'reduce', 'gather',
'col_to_row', 'row_to_col', 'All_to_All'
]
from typing import Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from .core import (ensure_divisibility, get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
def divide(numerator, denominator):
ensure_divisibility(numerator, denominator)
return numerator // denominator
def _reduce(tensor: Tensor) -> Tensor:
if dist.get_world_size() == 1:
return tensor
dist.all_reduce(tensor,
op=dist.ReduceOp.SUM,
group=get_tensor_model_parallel_group(),
async_op=False)
return tensor
def _split(tensor: Tensor, dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
return tensor
split_size = divide(tensor.shape[dim], get_tensor_model_parallel_world_size())
tensor_list = torch.split(tensor, split_size, dim=dim)
output = tensor_list[get_tensor_model_parallel_rank()].contiguous()
return output
def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
return tensor
if dim == 1 and list(tensor.shape)[0] == 1:
output_shape = list(tensor.shape)
output_shape[1] *= get_tensor_model_parallel_world_size()
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
dist.all_gather(list(tensor_list),
tensor,
group=get_tensor_model_parallel_group(),
async_op=False)
else:
tensor_list = [
torch.empty_like(tensor) for _ in range(get_tensor_model_parallel_world_size())
]
dist.all_gather(tensor_list,
tensor,
group=get_tensor_model_parallel_group(),
async_op=False)
output = torch.cat(tensor_list, dim=dim)
return output
def copy(input: Tensor) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Copy.apply(input)
return input
class Copy(torch.autograd.Function):
@staticmethod
def forward(ctx: "Copy", input: Tensor) -> Tensor:
return input
@staticmethod
def backward(ctx: "Copy", grad_output: Tensor) -> Tensor:
return _reduce(grad_output)
def scatter(input: Tensor, dim: int = -1) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Scatter.apply(input, dim)
else:
input = _split(input, dim=dim)
return input
class Scatter(torch.autograd.Function):
@staticmethod
def forward(ctx: "Scatter", input: Tensor, dim: int = -1) -> Tensor:
ctx.save_for_backward(torch.tensor([dim]))
return _split(input, dim=dim)
@staticmethod
def backward(ctx: "Scatter", grad_output: Tensor) -> Tuple[Tensor]:
dim, = ctx.saved_tensors
return _gather(grad_output, dim=int(dim)), None
def reduce(input: Tensor) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Reduce.apply(input)
else:
input = _reduce(input)
return input
class Reduce(torch.autograd.Function):
@staticmethod
def forward(ctx: "Reduce", input: Tensor) -> Tensor:
return _reduce(input)
@staticmethod
def backward(ctx: "Reduce", grad_output: Tensor) -> Tensor:
return grad_output
def gather(input: Tensor, dim: int = -1) -> Tensor:
if torch.is_grad_enabled() and input.requires_grad:
input = Gather.apply(input, dim)
else:
input = _gather(input, dim=dim)
return input
class Gather(torch.autograd.Function):
@staticmethod
def forward(ctx: "Gather", input: Tensor, dim: int = -1) -> Tensor:
ctx.save_for_backward(torch.tensor([dim]))
return _gather(input, dim=dim)
@staticmethod
def backward(ctx: "Gather", grad_output: Tensor) -> Tuple[Tensor]:
dim, = ctx.saved_tensors
return _split(grad_output, dim=int(dim)), None
def _all_to_all(tensor: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
if dist.get_world_size() == 1:
return tensor
tensor = tensor.transpose(in_dim, 0).contiguous()
output = torch.empty_like(tensor)
dist.all_to_all_single(output, tensor, group=get_tensor_model_parallel_group())
output = output.transpose(in_dim, 0).contiguous()
tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=in_dim)
return torch.cat(tensor_list, dim=out_dim)
def col_to_row(input_: Tensor) -> Tensor:
if torch.is_grad_enabled() and input_.requires_grad:
input_ = All_to_All.apply(input_, 1, 2)
else:
input_ = _all_to_all(input_, in_dim=1, out_dim=2)
return input_
def row_to_col(input_: Tensor) -> Tensor:
if torch.is_grad_enabled() and input_.requires_grad:
input_ = All_to_All.apply(input_, 2, 1)
else:
input_ = _all_to_all(input_, in_dim=2, out_dim=1)
return input_
class All_to_All(torch.autograd.Function):
@staticmethod
def forward(ctx: "All_to_All", input_: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
ctx.save_for_backward(torch.tensor([in_dim, out_dim]))
return _all_to_all(input_, in_dim=in_dim, out_dim=out_dim)
@staticmethod
def backward(ctx: "All_to_All", grad_output: Tensor) -> Tuple[Tensor]:
saved_tensors = ctx.saved_tensors[0]
return _all_to_all(grad_output, in_dim=int(saved_tensors[1]),
out_dim=int(saved_tensors[0])), None, None
import os
import torch
import torch.distributed as dist
from mpi4py import MPI
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# These values enable us to change the mpu sizes on the fly.
_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_TENSOR_MODEL_PARALLEL_RANK = None
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator)
def set_missing_distributed_environ(key, value):
if key not in os.environ:
os.environ[str(key)] = str(value)
def init_dist(tensor_model_parallel_size_=1):
comm = MPI.COMM_WORLD
world_size = comm.Get_size()
rank = comm.Get_rank()
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12340'
import habana_frameworks.torch.distributed.hccl
dist.init_process_group(backend='hccl', rank=rank, world_size=world_size)
world_size = dist.get_world_size()
rank = dist.get_rank()
# check dist config
ensure_divisibility(world_size, tensor_model_parallel_size_)
data_parallel_size_ = world_size // tensor_model_parallel_size_
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, \
'data parallel group is already initialized'
for i in range(tensor_model_parallel_size_):
ranks = range(i, world_size, tensor_model_parallel_size_)
group = dist.new_group(ranks)
if rank in ranks:
_DATA_PARALLEL_GROUP = group
global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
'tensor model parallel group is already initialized'
# Build the model-parallel groups.
for i in range(data_parallel_size_):
ranks = range(i * tensor_model_parallel_size_, (i + 1) * tensor_model_parallel_size_)
group = dist.new_group(ranks)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
if dist.get_rank() == 0:
print('> initialize tensor model parallel with size {}'.format(tensor_model_parallel_size_))
print('> initialize data parallel with size {}'.format(data_parallel_size_))
def dap_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
_DATA_PARALLEL_GROUP is None:
return False
return True
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
'intra_layer_model parallel group is not initialized'
return _TENSOR_MODEL_PARALLEL_GROUP
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, \
'data parallel group is not initialized'
return _DATA_PARALLEL_GROUP
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
global _TENSOR_MODEL_PARALLEL_WORLD_SIZE
if _TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
return _TENSOR_MODEL_PARALLEL_WORLD_SIZE
return dist.get_world_size(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
global _TENSOR_MODEL_PARALLEL_RANK
if _TENSOR_MODEL_PARALLEL_RANK is not None:
return _TENSOR_MODEL_PARALLEL_RANK
return dist.get_rank(group=get_tensor_model_parallel_group())
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return dist.get_world_size(group=get_data_parallel_group())
def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return dist.get_rank(group=get_data_parallel_group())
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = dist.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
from functools import partial
from typing import Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from fastfold.habana.distributed import All_to_All, gather, scatter
from fastfold.utils.checkpointing import checkpoint_blocks
from .msa import ExtraMSACore, MSAStack
from .ops import Linear, OutProductMean
from .triangle import PairStack
import habana_frameworks.torch.core as htcore
class Evoformer(nn.Module):
def __init__(self,
c_m: int,
c_z: int,
first_block: bool,
last_block: bool,
is_multimer: bool = False):
super(Evoformer, self).__init__()
self.first_block = first_block
self.last_block = last_block
self.msa = MSAStack(c_m, c_z, p_drop=0.15)
self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
self.pair = PairStack(d_pair=c_z)
self.is_multimer = is_multimer
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = dist.get_world_size()
seq_length = pair_mask.size(-1)
# Pad the residue dimension up to a multiple of dap_size; padding_size is always >= 1,
# so the un-padding slices in the last block are well-defined.
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
if self.first_block:
m = m.unsqueeze(0)
z = z.unsqueeze(0)
m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
if self.is_multimer:
m = scatter(m, dim=2)
else:
m = scatter(m, dim=1)
z = scatter(z, dim=1)
# msa_mask = msa_mask.unsqueeze(0)
# pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))
if not self.is_multimer:
m = self.msa(m, z, msa_mask)
z = self.communication(m, msa_mask, z)
m = All_to_All.apply(m, 1, 2)
z = self.pair(z, pair_mask)
else:
z = self.communication(m, msa_mask, z)
z_ori = z
m = All_to_All.apply(m, 1, 2)
z = self.pair(z, pair_mask)
m = self.msa(m, z_ori, msa_mask)
if self.last_block:
m = m.squeeze(0)
z = z.squeeze(0)
if self.is_multimer:
m = gather(m, dim=1)
else:
m = gather(m, dim=0)
z = gather(z, dim=0)
m = m[:, :-padding_size, :]
z = z[:-padding_size, :-padding_size, :]
htcore.mark_step()
return m, z
class EvoformerStack(nn.Module):
"""
Main Evoformer trunk.
Implements Algorithm 6.
"""
def __init__(
self,
c_m: int,
c_z: int,
c_s: int,
no_blocks: int,
blocks_per_ckpt: int,
clear_cache_between_blocks: bool = False,
is_multimer: bool = False,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair channel dimension
c_hidden_msa_att:
Hidden dimension in MSA attention
c_hidden_opm:
Hidden dimension in outer product mean module
c_hidden_mul:
Hidden dimension in multiplicative updates
c_hidden_pair_att:
Hidden dimension in triangular attention
c_s:
Channel dimension of the output "single" embedding
no_heads_msa:
Number of heads used for MSA attention
no_heads_pair:
Number of heads used for pair attention
no_blocks:
Number of Evoformer blocks in the stack
transition_n:
Factor by which to multiply c_m to obtain the MSATransition
hidden dimension
msa_dropout:
Dropout rate for MSA activations
pair_dropout:
Dropout used for pair activations
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
"""
super(EvoformerStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self.blocks = nn.ModuleList()
for block_id in range(no_blocks):
block = Evoformer(
c_m=c_m,
c_z=c_z,
first_block=(block_id == 0),
last_block=(block_id == no_blocks - 1),
is_multimer=is_multimer,
)
self.blocks.append(block)
self.linear = Linear(c_m, c_s)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
) for b in self.blocks
]
if torch.is_grad_enabled():
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
else:
for b in blocks:
m, z = b(m, z)
s = self.linear(m[..., 0, :, :])
htcore.mark_step()
return m, z, s
class ExtraMSABlock(nn.Module):
def __init__(self,
c_m: int,
c_z: int,
first_block: bool,
last_block: bool,
is_multimer: bool = False):
super(ExtraMSABlock, self).__init__()
self.first_block = first_block
self.last_block = last_block
self.msa_stack = ExtraMSACore(c_m, c_z, p_drop=0.15)
self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
self.pair_stack = PairStack(d_pair=c_z)
self.is_multimer = is_multimer
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
htcore.mark_step()
dap_size = dist.get_world_size()
seq_length = pair_mask.size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
if self.first_block:
m = m.unsqueeze(0)
z = z.unsqueeze(0)
m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))
if self.is_multimer:
m = scatter(m, dim=2)
else:
m = scatter(m, dim=1)
z = scatter(z, dim=1)
msa_mask = msa_mask.unsqueeze(0)
pair_mask = pair_mask.unsqueeze(0)
msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))
if not self.is_multimer:
m = self.msa_stack(m, z, msa_mask)
z = self.communication(m, msa_mask, z)
m = All_to_All.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
else:
z = self.communication(m, msa_mask, z)
z_ori = z
m = All_to_All.apply(m, 1, 2)
z = self.pair_stack(z, pair_mask)
m = self.msa_stack(m, z_ori, msa_mask)
if self.last_block:
m = m.squeeze(0)
z = z.squeeze(0)
if self.is_multimer:
m = gather(m, dim=1)
else:
m = gather(m, dim=0)
z = gather(z, dim=0)
m = m[:, :-padding_size, :]
z = z[:-padding_size, :-padding_size, :]
htcore.mark_step()
return m, z
class ExtraMSAStack(nn.Module):
"""
Implements Algorithm 18.
"""
def __init__(
self,
c_m: int,
c_z: int,
no_blocks: int,
blocks_per_ckpt: int,
clear_cache_between_blocks: bool = False,
is_multimer: bool = False,
**kwargs,
):
super(ExtraMSAStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
self.blocks = nn.ModuleList()
for block_id in range(no_blocks):
block = ExtraMSABlock(
c_m=c_m,
c_z=c_z,
first_block=(block_id == 0),
last_block=(block_id == no_blocks - 1),
is_multimer=is_multimer,
)
self.blocks.append(block)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor:
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
) for b in self.blocks
]
if torch.is_grad_enabled():
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
else:
for b in blocks:
m, z = b(m, z)
return z
# CustomOp API Usage in PyTorch
This README provides an example of how to write custom PyTorch Ops using a TPC Kernel supported on an HPU device. For more details, refer to [PyTorch CustomOP API](https://docs.habana.ai/en/latest/PyTorch/PyTorch_CustomOp_API/page_index.html) documentation.
For further information on training deep learning models using Gaudi, refer to [developer.habana.ai](https://developer.habana.ai/resources/).
## Table of Contents
* [Model-References](../../../README.md)
* [Prerequisites](#prerequisites)
* [Content](#content)
* [Build and Run with Custom Kernels](#build-and-run-with-custom-kernels)
* [Important to Know](#important-to-know)
* [Applying CustomOps to a Real Training Model Example](#applying-customops-to-a-real-training-model-example)
* [Known Issues](#known-issues)
## Prerequisites
- A TPC kernel on which the HpuKernel will run. To write a CustomOp, you must first define the TPC kernel that the HpuKernel will run on. This document provides the required steps for using the TPC kernels `fusedsoftmax_fwd_f32`, `fusedsoftmax_bias_fwd_f32` and `softmax_bwd_f32` (or their Gaudi2 variants) to implement the `custom_op::fusedsoftmax`, `custom_op::fusedsoftmax_bias` and `custom_op::fusedsoftmax_backward` CustomOps. For further information on how to write TPC kernels, refer to the [Habana Custom Kernel GitHub page](https://github.com/HabanaAI/Habana_Custom_Kernel).
- The **habana-torch-plugin** Python package must be installed. Make sure to install it by following the instructions detailed in the [Installation Guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html).
## Content
- C++ file with **custom_op::fusedsoftmax** and **custom_op::fusedsoftmax_bias** definitions and their kernel implementation on HPU (reference semantics are sketched after this list):
  - `fusedsoftmax` performs a fused softmax on input and mask.
  - `fusedsoftmax_bias` performs a fused softmax on input, mask and bias.
- `setup.py` file for building the solution:
  - To compile the Op for Gaudi, run ```python setup.py build```.
  - To compile the Op for Gaudi2, run ```python setup2.py build```.
- Python test to run and validate `fusedsoftmax` and `fusedsoftmax_bias`:
- ```python hpu_fusedsoftmax_test.py```
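
For intuition, the two ops are numerically equivalent to the unfused PyTorch computation below. This is only a minimal reference sketch: the `*_reference` names are illustrative (not part of the extension), and `mask`/`bias` are assumed to broadcast against `input`.

```python
import torch

def fusedsoftmax_reference(input, mask, dim=-1):
    # Unfused equivalent: add the broadcast mask, then softmax along `dim`.
    return torch.softmax(input + mask, dim=dim)

def fusedsoftmax_bias_reference(input, mask, bias, dim=-1):
    # Same as above, with an additional broadcast bias term.
    return torch.softmax(input + mask + bias, dim=dim)
```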
## Build and Run with Custom Kernels
To build and run `fusedsoftmax` and `fusedsoftmax_bias`, run the following (use `setup2.py` on Gaudi2):
```python setup.py build```
## Important to Know
This is an example of an Op implementing both forward and backward.
The forward and backward CustomOps are used for training the model by extending the [torch.autograd](https://pytorch.org/docs/stable/notes/extending.html) package.
## Known Issues
BF16 and HMP are not supported yet. To use the CustomOp in a topology, run the FP32 variant only.
###############################################################################
# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
###############################################################################
from .fusedsoftmax import fused_softmax, fused_softmax_bias
__all__ = ['fused_softmax', 'fused_softmax_bias']
###############################################################################
# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
###############################################################################
import torch
import os
import habana_frameworks.torch.core
custom_fusedsoftmax_op_lib_path = "./build/lib.linux-x86_64-3.8/hpu_fusedsoftmax.cpython-38-x86_64-linux-gnu.so"
my_dir = os.path.realpath(__file__)
my_len = my_dir.rfind('/')
base_dir = my_dir[:my_len]
torch.ops.load_library(os.path.join(base_dir, custom_fusedsoftmax_op_lib_path))
class FusedSoftmaxFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask, dim):
# ctx is a context object that can be used to stash information
# for backward computation
tensor = torch.ops.custom_op.fusedsoftmax(input, mask, dim)
ctx.y = tensor
ctx.dim = dim
return tensor
@staticmethod
def backward(ctx, grad_output):
if grad_output is None:
return None
y = ctx.y
ctx.y = None
dim = ctx.dim
ctx.dim = None
grad_input = torch.ops.custom_op.fusedsoftmax_backward(y, grad_output, dim)
return grad_input, None, None
class FusedSoftmaxBiasFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask, bias, dim):
# ctx is a context object that can be used to stash information
# for backward computation
tensor = torch.ops.custom_op.fusedsoftmax_bias(input, mask, bias, dim)
ctx.y = tensor
ctx.dim = dim
ctx.use_bias = False
if bias is not None:
ctx.use_bias = True
return tensor
@staticmethod
def backward(ctx, grad_output):
if grad_output is None:
return None
y = ctx.y
ctx.y = None
dim = ctx.dim
ctx.dim = None
grad_input = torch.ops.custom_op.fusedsoftmax_backward(y, grad_output, dim)
grad_bias = None
if ctx.use_bias:
grad_bias = torch.sum(grad_input, dim=-4, keepdim=True)
return grad_input, None, grad_bias, None
ENABLE_OPT = True
# fused_softmax uses the fused HPU kernel only when the mask has the expected broadcast
# shape; otherwise it falls back to a plain (unfused) add + softmax.
def fused_softmax(input, mask, dim):
if ENABLE_OPT and input[..., :, :1, :1, :].shape == mask.shape:
return FusedSoftmaxFunction.apply(input, mask, dim)
else:
input += mask
return torch.softmax(input, dim=dim)
def fused_softmax_bias(input, mask, bias, dim):
if ENABLE_OPT and input[..., :, :1, :1, :].shape == mask.shape and input[..., :1, :, :, :].shape == bias.shape:
return FusedSoftmaxBiasFunction.apply(input, mask, bias, dim)
else:
input += mask
input += bias
return torch.softmax(input, dim=dim)
/******************************************************************************
###############################################################################
# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
###############################################################################
*******************************************************************************/
#include "hpu_custom_op.h"
#include <torch/extension.h>
#include <perf_lib_layer_params.h>
struct SoftMaxParam
{
int32_t axis;
bool with_bias;
};
bool register_fusedsoftmax() {
// Registering custom_op::fusedsoftmax
// inputs desc
habana::custom_op::InputDesc input_a_desc{
habana::custom_op::input_type::TENSOR, 0};
habana::custom_op::InputDesc input_b_desc{
habana::custom_op::input_type::TENSOR, 1};
habana::custom_op::InputDesc input_d_desc{
habana::custom_op::input_type::USER_PARAMS, 2};
std::vector<habana::custom_op::InputDesc> inputs_desc{
input_a_desc, input_b_desc, input_d_desc};
// output desc
// output shape callback
auto output_size_lambda =
[](const at::Stack& inputs) -> std::vector<int64_t> {
auto self = inputs[0].toTensor(); // input
std::vector<int64_t> result_sizes = self.sizes().vec();
return result_sizes;
};
habana::custom_op::OutputDesc output_desc{
0, c10::ScalarType::Float, output_size_lambda};
std::vector<habana::custom_op::OutputDesc> outputs_desc{
output_desc};
// user param callback
auto user_params_lambda = [](const at::Stack& inputs, size_t& size) {
HPU_PARAMS_STUB(SoftMaxParam);
params->with_bias = false;
int dim = inputs[2].toInt();
if (dim > 0)
params->axis = inputs[0].toTensor().dim() - dim - 1;
else
params->axis = - dim - 1;
return params;
};
// actual register
REGISTER_CUSTOM_OP_ATTRIBUTES(
"custom_op::fusedsoftmax", //schema name
#ifdef GAUDI2
"fusedsoftmax_fwd_f32_gaudi2", // guid
#else
"fusedsoftmax_fwd_f32", // guid
#endif
inputs_desc,
outputs_desc,
user_params_lambda);
std::cout << "cpp registered custom_op::fusedsoftmax\n";
return true;
}
bool register_fusedsoftmax_bias() {
// Registering custom_op::fusedsoftmax_bias
// inputs desc
habana::custom_op::InputDesc input_a_desc{
habana::custom_op::input_type::TENSOR, 0};
habana::custom_op::InputDesc input_b_desc{
habana::custom_op::input_type::TENSOR, 1};
habana::custom_op::InputDesc input_c_desc{
habana::custom_op::input_type::TENSOR, 2};
habana::custom_op::InputDesc input_d_desc{
habana::custom_op::input_type::USER_PARAMS, 3};
std::vector<habana::custom_op::InputDesc> inputs_desc{
input_a_desc, input_b_desc, input_c_desc, input_d_desc};
// output desc
// output shape callback
auto output_size_lambda =
[](const at::Stack& inputs) -> std::vector<int64_t> {
auto self = inputs[0].toTensor(); // input
std::vector<int64_t> result_sizes = self.sizes().vec();
return result_sizes;
};
habana::custom_op::OutputDesc output_desc{
0, c10::ScalarType::Float, output_size_lambda};
std::vector<habana::custom_op::OutputDesc> outputs_desc{
output_desc};
// user param callback
auto user_params_lambda = [](const at::Stack& inputs, size_t& size) {
HPU_PARAMS_STUB(SoftMaxParam);
params->with_bias = true;
int dim = inputs[3].toInt();
if (dim > 0)
params->axis = inputs[0].toTensor().dim() - dim - 1;
else
params->axis = - dim - 1;
return params;
};
// actual register
REGISTER_CUSTOM_OP_ATTRIBUTES(
"custom_op::fusedsoftmax_bias", //schema name
#ifdef GAUDI2
"fusedsoftmax_bias_fwd_f32_gaudi2", // guid
#else
"fusedsoftmax_bias_fwd_f32", // guid
#endif
inputs_desc,
outputs_desc,
user_params_lambda);
std::cout << "cpp registered custom_op::fusedsoftmax_bias\n";
return true;
}
bool register_custom_fusedsoftmax_backward() {
// inputs desc
habana::custom_op::InputDesc y_desc{
habana::custom_op::input_type::TENSOR, 0};
habana::custom_op::InputDesc grad_desc{
habana::custom_op::input_type::TENSOR, 1};
habana::custom_op::InputDesc dim_desc{
habana::custom_op::input_type::USER_PARAMS, 2};
std::vector<habana::custom_op::InputDesc> inputs_desc{
y_desc, grad_desc, dim_desc};
auto output_input_size_lambda =
[](const at::Stack& inputs) -> std::vector<int64_t> {
auto self = inputs[0].toTensor(); // input
std::vector<int64_t> result_sizes = self.sizes().vec();
return result_sizes;
};
habana::custom_op::OutputDesc input_grad_desc{
0, c10::ScalarType::Float, output_input_size_lambda};
std::vector<habana::custom_op::OutputDesc> outputs_desc{
input_grad_desc};
// user param callback
auto user_params_lambda = [](const at::Stack& inputs, size_t& size) {
HPU_PARAMS_STUB(ns_Softmax::Params);
params->dim = 0;
return params;
};
// actual register
REGISTER_CUSTOM_OP_ATTRIBUTES(
"custom_op::fusedsoftmax_backward", //schema name
#ifdef GAUDI2
"softmax_bwd_f32", // guid
#else
"softmax_bwd_f32", // guid
#endif
inputs_desc,
outputs_desc,
user_params_lambda);
std::cout << "cpp registered custom_op::fusedsoftmax_backward\n";
return true;
}
at::Tensor fusedsoftmax_execute(
torch::Tensor input,
torch::Tensor mask,
at::Scalar dim) {
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Float, "Input input_a expected to be Float tensor");
// Registering the custom op, need to be called only once
static bool registered = register_fusedsoftmax();
TORCH_CHECK(registered, "fusedsoftmax kernel not registered" );
std::vector<c10::IValue> inputs{input, mask, dim};
// Get custom op descriptor from registry
auto op_desc = habana::custom_op::HabanaCustomOpDescriptor::getCustomOpDescriptor("custom_op::fusedsoftmax");
// Actual call for op execution
std::vector<at::Tensor> output = op_desc.execute(inputs);
// op_desc.execute will always return a vector
return output[0];
}
at::Tensor fusedsoftmax_bias_execute(
torch::Tensor input,
torch::Tensor mask,
torch::Tensor bias,
at::Scalar dim) {
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Float, "Input input_a expected to be Float tensor");
// Registering the custom op, need to be called only once
static bool registered = register_fusedsoftmax_bias();
TORCH_CHECK(registered, "fusedsoftmax_bias kernel not registered" );
std::vector<c10::IValue> inputs{input, mask, bias, dim};
// Get custom op descriptor from registry
auto op_desc = habana::custom_op::HabanaCustomOpDescriptor::getCustomOpDescriptor("custom_op::fusedsoftmax_bias");
// Actual call for op execution
std::vector<at::Tensor> output = op_desc.execute(inputs);
// op_desc.execute will always return a vector
return output[0];
}
at::Tensor fusedsoftmax_backward_execute(
torch::Tensor y,
torch::Tensor grad,
at::Scalar dim) {
TORCH_CHECK(y.scalar_type() == c10::ScalarType::Float, "Input y expected to be Float tensor");
TORCH_CHECK(grad.scalar_type() == c10::ScalarType::Float, "Input grad expected to be Float tensor");
// Registering the custom op, need to be called only once
static bool registered = register_custom_fusedsoftmax_backward();
TORCH_CHECK(registered, "custom_fusedsoftmax_backward kernel not registered" );
std::vector<c10::IValue> inputs{y, grad, dim};
// Get custom op descriptor from registry
auto op_desc = habana::custom_op::HabanaCustomOpDescriptor::getCustomOpDescriptor("custom_op::fusedsoftmax_backward");
// Actual call for op execution
std::vector<at::Tensor> output = op_desc.execute(inputs);
// op_desc.execute will always return a vector
return output[0];
}
TORCH_LIBRARY(custom_op, m) {
m.def("fusedsoftmax(Tensor self, Tensor mask, Scalar dim) -> Tensor");
m.def("fusedsoftmax_bias(Tensor self, Tensor mask, Tensor bias, Scalar dim) -> Tensor");
m.def("fusedsoftmax_backward(Tensor y, Tensor grad, Scalar dim) -> Tensor");
}
TORCH_LIBRARY_IMPL(custom_op, HPU, m) {
m.impl("fusedsoftmax", fusedsoftmax_execute);
m.impl("fusedsoftmax_bias", fusedsoftmax_bias_execute);
m.impl("fusedsoftmax_backward", fusedsoftmax_backward_execute);
}
###############################################################################
# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
###############################################################################
import torch
from fusedsoftmax import fused_softmax, fused_softmax_bias
def test_fusedsoftmax_op_function():
print(torch.ops.custom_op.fusedsoftmax)
print(torch.ops.custom_op.fusedsoftmax_bias)
input = torch.randn(1, 512, 4, 512, 512)
mask = torch.randn(1, 512, 1, 1, 512)
bias = torch.randn(1, 1, 4, 512, 512)
dim = -1
input_hpu = input.to('hpu')
mask_hpu = mask.to('hpu')
out = input + mask
output_cpu = torch.softmax(out, dim=dim)
output_hpu = fused_softmax(input_hpu, mask_hpu, dim)
assert((abs(output_hpu.cpu() - output_cpu) < 1e-6).all())
print("fused_softmax test passed")
input_hpu = input.to('hpu')
mask_hpu = mask.to('hpu')
bias_hpu = bias.to('hpu')
out = input + mask
out += bias
output_cpu = torch.softmax(out, dim=dim)
output_hpu = fused_softmax_bias(input_hpu, mask_hpu, bias_hpu, dim)
assert((abs(output_hpu.cpu() - output_cpu) < 1e-6).all())
print("fused_softmax_bias test passed")
test_fusedsoftmax_op_function()
def test_fusedsoftmax_bias_op_backward_function():
print("fused_softmax_bias_backward")
input = torch.randn(1, 512, 4, 512, 512, requires_grad=True)
mask = torch.randn(1, 512, 1, 1, 512, requires_grad=False)
bias = torch.randn(1, 1, 4, 512, 512, requires_grad=True)
dim = -1
# cpu reference
add_mask_cpu = input + mask
add_mask_cpu += bias
softmax_cpu = torch.softmax(add_mask_cpu, dim=dim)
input_hpu = input.to('hpu').detach()
input_hpu.requires_grad = True
mask_hpu = mask.to('hpu').detach()
mask_hpu.requires_grad = False
bias_hpu = bias.to('hpu').detach()
bias_hpu.requires_grad = True
softmax_hpu = fused_softmax_bias(input_hpu, mask_hpu, bias_hpu, dim)
assert((abs(softmax_hpu.detach().cpu() - softmax_cpu.detach()) < 1e-6).all())
grad_cpu = torch.ones_like(softmax_cpu)
softmax_cpu.backward(grad_cpu)
grad_hpu = grad_cpu.to('hpu')
softmax_hpu.backward(grad_hpu)
input_bwd_cpu = input.grad
input_bwd_hpu = input_hpu.grad
assert((abs(input_bwd_hpu.detach().cpu() - input_bwd_cpu.detach()) < 1e-6).all())
bias_bwd_cpu = bias.grad
bias_bwd_hpu = bias_hpu.grad
assert((abs(bias_bwd_hpu.detach().cpu() - bias_bwd_cpu.detach()) < 1e-6).all())
print("fused_softmax_bias_backward test passed")
test_fusedsoftmax_bias_op_backward_function()
def test_fusedsoftmax_op_backward_function():
print(torch.ops.custom_op.fusedsoftmax_backward)
input = torch.randn(1, 512, 4, 512, 512, requires_grad=True)
mask = torch.randn(1, 512, 1, 1, 512, requires_grad=False)
dim = -1
# cpu reference
add_mask_cpu = input + mask
softmax_cpu = torch.softmax(add_mask_cpu, dim=dim)
input_hpu = input.to('hpu').detach()
input_hpu.requires_grad = True
mask_hpu = mask.to('hpu').detach()
mask_hpu.requires_grad = False
softmax_hpu = fused_softmax(input_hpu, mask_hpu, dim)
assert((abs(softmax_hpu.detach().cpu() - softmax_cpu.detach()) < 1e-6).all())
grad_cpu = torch.ones_like(softmax_cpu)
softmax_cpu.backward(grad_cpu)
grad_hpu = grad_cpu.to('hpu')
softmax_hpu.backward(grad_hpu)
input_bwd_cpu = input.grad
input_bwd_hpu = input_hpu.grad
assert((abs(input_bwd_hpu.detach().cpu() - input_bwd_cpu.detach()) < 1e-6).all())
print("fused_softmax_backward test passed")
test_fusedsoftmax_op_backward_function()
###############################################################################
# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
###############################################################################
from setuptools import setup
from torch.utils import cpp_extension
from habana_frameworks.torch.utils.lib_utils import get_include_dir, get_lib_dir
import os
import pybind11
torch_include_dir = get_include_dir()
torch_lib_dir = get_lib_dir()
habana_modules_directory = "/usr/include/habanalabs"
pybind_include_path = pybind11.get_include()
setup(name='hpu_fusedsoftmax',
ext_modules=[cpp_extension.CppExtension('hpu_fusedsoftmax', ['hpu_fusedsoftmax.cpp'],
language='c++', extra_compile_args=["-std=c++17"],
libraries=['habana_pytorch_plugin'],
library_dirs=[torch_lib_dir])],
include_dirs=[torch_include_dir,
habana_modules_directory,
pybind_include_path,
],
cmdclass={'build_ext': cpp_extension.BuildExtension})
###############################################################################
# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
###############################################################################
from setuptools import setup
from torch.utils import cpp_extension
from habana_frameworks.torch.utils.lib_utils import get_include_dir, get_lib_dir
import os
import pybind11
torch_include_dir = get_include_dir()
torch_lib_dir = get_lib_dir()
habana_modules_directory = "/usr/include/habanalabs"
pybind_include_path = pybind11.get_include()
setup(name='hpu_fusedsoftmax',
ext_modules=[cpp_extension.CppExtension('hpu_fusedsoftmax', ['hpu_fusedsoftmax.cpp'],
language='c++', extra_compile_args=["-std=c++17"], define_macros=[("GAUDI2", None)],
libraries=['habana_pytorch_plugin'],
library_dirs=[torch_lib_dir])],
include_dirs=[torch_include_dir,
habana_modules_directory,
pybind_include_path,
],
cmdclass={'build_ext': cpp_extension.BuildExtension})
import math
import numpy as np
import torch.nn as nn
def glorot_uniform_af(x, gain=1.0):
"""
initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different:
In PyTorch:
[feature_out, feature_in, n_head ...]
In Jax:
[... n_head, feature_in, feature_out]
However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like:
[feature_in, n_head, feature_out]
In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors
"""
fan_in, fan_out = x.shape[-2:]
if len(x.shape) > 2:
receptive_field_size = np.prod(x.shape[:-2])
fan_in *= receptive_field_size
fan_out *= receptive_field_size
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
nn.init.uniform_(x, -dev, dev)
return x
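# The block below is an optional usage sketch (illustrative shapes only, not part of the
# original module): it initializes a hypothetical attention projection weight laid out as
# [feature_in, n_head, feature_out]. fan_in/fan_out are taken from the last two dims and
# scaled by the product of the leading dims, matching the convention described above.
if __name__ == "__main__":
    import torch

    weight = torch.empty(256, 8, 32)  # hypothetical [feature_in, n_head, feature_out]
    glorot_uniform_af(weight, gain=1.0)
    print("initialized:", tuple(weight.shape), float(weight.std()))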
import torch
import torch.nn.functional as F
def bias_sigmod_ele(y, bias, z):
return torch.sigmoid(y + bias) * z
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
out = (x + bias) * F.dropout(dropmask, p=prob, training=training)
out = residual + out
return out
def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
dropout_mask: torch.Tensor, Z_raw: torch.Tensor, prob: float,
training: bool) -> torch.Tensor:
return Z_raw + F.dropout(dropout_mask, p=prob, training=training) * (g * (ab + b))
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm
from fastfold.habana.distributed import gather, row_to_col, scatter
from .kernel import bias_dropout_add
from .ops import GlobalAttention, SelfAttention, Transition
class MSAColumnGlobalAttention(nn.Module):
def __init__(self, d_node, c=8, n_head=8):
super(MSAColumnGlobalAttention, self).__init__()
self.d_node = d_node
self.c = c
self.n_head = n_head
self.layernormM = LayerNorm(d_node)
self.global_attention = GlobalAttention(qkv_dim=d_node, c=c, n_head=n_head, out_dim=d_node)
def forward(self, M_raw, M_mask):
M = M_raw.transpose(-2, -3)
M = self.layernormM(M)
M_mask = M_mask.transpose(-1, -2)
M = self.global_attention(M, M_mask)
M = M.transpose(-2, -3)
return M_raw + M
class MSARowAttentionWithPairBias(nn.Module):
def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
super(MSARowAttentionWithPairBias, self).__init__()
self.d_node = d_node
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernormM = LayerNorm(d_node)
self.layernormZ = LayerNorm(d_pair)
_init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
std=1.0 / math.sqrt(d_pair))
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)
self.attention = SelfAttention(qkv_dim=d_node,
c=c,
n_head=n_head,
out_dim=d_node,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)
def forward(self, M_raw, Z, M_mask):
## Input projections
M = self.layernormM(M_raw)
Z = self.layernormZ(Z)
b = F.linear(Z, self.linear_b_weights)
b = gather(b, dim=1)
b = rearrange(b, 'b q k h -> b h q k')
M = self.attention(M, M_mask, b)
dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
return bias_dropout_add(M,
self.out_bias,
dropout_mask,
M_raw,
prob=self.p_drop,
training=self.training)
class MSAColumnAttention(nn.Module):
def __init__(self, d_node, c=32, n_head=8):
super(MSAColumnAttention, self).__init__()
self.d_node = d_node
self.c = c
self.n_head = n_head
self.layernormM = LayerNorm(d_node)
self.attention = SelfAttention(qkv_dim=d_node,
c=c,
n_head=n_head,
out_dim=d_node,
gating=True)
def forward(self, M_raw, M_mask):
M = M_raw.transpose(-2, -3)
M = self.layernormM(M)
M_mask = M_mask.transpose(-1, -2)
M = self.attention(M, M_mask)
M = M.transpose(-2, -3)
return M_raw + M
class MSAStack(nn.Module):
def __init__(self, d_node, d_pair, p_drop=0.15):
super(MSAStack, self).__init__()
self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
d_pair=d_pair,
p_drop=p_drop)
self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
self.MSATransition = Transition(d=d_node)
def forward(self, node, pair, node_mask):
node_mask_row = scatter(node_mask, dim=1)
node = self.MSARowAttentionWithPairBias(node, pair, node_mask_row)
node = row_to_col(node)
node_mask_col = scatter(node_mask, dim=2)
node = self.MSAColumnAttention(node, node_mask_col)
node = self.MSATransition(node)
return node
class ExtraMSACore(nn.Module):
def __init__(self, d_node, d_pair, p_drop=0.15):
super(ExtraMSACore, self).__init__()
self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
d_pair=d_pair,
p_drop=p_drop,
c=8)
self.MSAColumnAttention = MSAColumnGlobalAttention(d_node=d_node, c=8)
self.MSATransition = Transition(d=d_node)
def forward(self, node, pair, node_mask):
node_mask_row = scatter(node_mask, dim=1)
node = self.MSARowAttentionWithPairBias(node, pair, node_mask_row)
node = row_to_col(node)
node_mask_col = scatter(node_mask, dim=2)
node = self.MSAColumnAttention(node, node_mask_col)
node = self.MSATransition(node)
return node