Unverified commit 126ba573 authored by Jiarui Fang, committed by GitHub

[Tensor] add layer norm Op (#852)

parent a82da26f
 from .init import colo_uniform
 from .linear import colo_linear
 from .element_wise import colo_mean
+from .layernorm import colo_layernorm
\ No newline at end of file
@@ -5,8 +5,10 @@ from colossalai.tensor import ColoTensor

 @colo_op_impl(torch.mean)
 def colo_mean(types, args=(), kwargs=None, pg=None):
-    stateful_tensor = args[0]
-    return torch.mean(stateful_tensor.torch_tensor())
+    input_t = args[0]
+    if isinstance(input_t, ColoTensor):
+        input_t = input_t.torch_tensor()
+    return ColoTensor.init_from_torch_tensor(torch.mean(input_t))

 def register_elementwise_op(op):
@@ -22,7 +24,7 @@ def register_elementwise_op(op):
         # Validate types
         if not isinstance(input_tensor, ColoTensor):
             raise TypeError("input needs to be a ColoTensor")
-        return op(input_tensor.torch_tensor())
+        return ColoTensor.init_from_torch_tensor(op(input_tensor.torch_tensor()))

 register_elementwise_op(torch.nn.functional.gelu)
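After this hunk, every wrapped element-wise op returns a ColoTensor instead of a bare torch.Tensor. A minimal sketch of what a caller now sees, assuming the op registrations above and `__torch_function__`-style dispatch on ColoTensor (exactly what the tests below exercise):

# Sketch only: uses nothing beyond the names in this diff.
import torch
from colossalai.tensor import ColoTensor

t = ColoTensor.init_from_torch_tensor(torch.randn(3, 5))
out = torch.nn.functional.gelu(t)   # routed to the wrapped op
assert isinstance(out, ColoTensor)  # result stays inside the abstraction
print(out.torch_tensor().shape)     # unwrap to reach the raw payload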
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor


@colo_op_impl(torch.nn.functional.layer_norm)
def colo_layernorm(types, args=(), kwargs=None, pg=None):
    # Defaults follow torch.nn.functional.layer_norm(input, normalized_shape,
    # weight=None, bias=None, eps=1e-5).
    input_tensor = None
    normalized_shape = None
    weight = None
    bias = None
    eps = 1e-5

    # Positional arguments: weight/bias/eps sit at indices 2/3/4,
    # matching the signature above.
    arg_num = len(args)
    if arg_num > 0:
        input_tensor = args[0]
    if arg_num > 1:
        normalized_shape = args[1]
    if arg_num > 2:
        weight = args[2]
    if arg_num > 3:
        bias = args[3]
    if arg_num > 4:
        eps = args[4]

    # Keyword arguments take precedence; guard against kwargs=None.
    kwargs = kwargs or {}
    if 'input' in kwargs:
        input_tensor = kwargs['input']
    if 'weight' in kwargs:
        weight = kwargs['weight']
    if 'bias' in kwargs:
        bias = kwargs['bias']
    if 'eps' in kwargs:
        eps = kwargs['eps']

    # Unwrap ColoTensor payloads before calling the native op,
    # then wrap the result back into a ColoTensor.
    if isinstance(input_tensor, ColoTensor):
        input_tensor = input_tensor.torch_tensor()
    if isinstance(weight, ColoTensor):
        weight = weight.torch_tensor()
    if isinstance(bias, ColoTensor):
        bias = bias.torch_tensor()

    return ColoTensor.init_from_torch_tensor(
        torch.nn.functional.layer_norm(input_tensor, normalized_shape, weight, bias, eps))
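As a quick illustration, this is how the new handler is expected to be reached; a sketch assuming ColoTensor arguments trigger the dispatch (the real coverage is in test_layernorm further down):

# Sketch only: mirrors what test_layernorm does through torch.nn.LayerNorm.
import torch
from colossalai.tensor import ColoTensor

x = ColoTensor.init_from_torch_tensor(torch.randn(3, 2))
w = ColoTensor.init_from_torch_tensor(torch.ones(2))
b = ColoTensor.init_from_torch_tensor(torch.zeros(2))

# With ColoTensor arguments, F.layer_norm is intercepted and lands in
# colo_layernorm above, which unwraps, computes, and re-wraps the result.
y = torch.nn.functional.layer_norm(x, (2,), w, b, 1e-5)
assert isinstance(y, ColoTensor)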
@@ -8,6 +8,7 @@ from colossalai.context import ParallelMode
 from colossalai.nn.layer.utils import divide
 from colossalai.utils.cuda import get_current_device

 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
     1. It contains a torch.Tensor as an attribute.
@@ -145,3 +146,6 @@ class ColoTensor(object):
         kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
         return func(*args, **kwargs)
+
+    def backward(self, retain_graph: bool = False):
+        self._torch_tensor.backward(retain_graph=retain_graph)
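The new backward method simply forwards to the wrapped tensor's autograd. A minimal sketch, assuming torch.mean on a ColoTensor returns a ColoTensor as in the element_wise hunk above and that wrapping preserves the autograd graph (the pattern test_layernorm relies on):

# Sketch only: gradients accumulate on the underlying torch.Tensor.
import torch
from colossalai.tensor import ColoTensor

t = torch.randn(4, requires_grad=True)
loss = torch.mean(ColoTensor.init_from_torch_tensor(t))  # a ColoTensor scalar
loss.backward()   # delegates to self._torch_tensor.backward()
print(t.grad)     # grads landed on the wrapped leaf tensor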
-from numpy import allclose, require
+from numpy import allclose
 import torch
 from colossalai.tensor import ColoTensor
 from copy import deepcopy
+from colossalai.utils import get_current_device
+
+
+def test_layernorm():
+    ln_op = torch.nn.LayerNorm(2, 3, device=get_current_device())
+    ln_op_colo = deepcopy(ln_op)
+
+    input_t = torch.randn(3, 2, device=get_current_device())
+    input_t_colo = ColoTensor.init_from_torch_tensor(tensor=input_t.clone().detach())
+
+    # prepare colossalai LN
+    delattr(ln_op_colo, 'weight')
+    weight_clone = ln_op.weight.clone().detach()
+    weight_clone.requires_grad = True
+    setattr(ln_op_colo, 'weight', ColoTensor.init_from_torch_tensor(tensor=weight_clone))
+
+    output = ln_op(input_t)
+    output_colo = ln_op_colo(input_t_colo)
+    assert allclose(output_colo.torch_tensor().detach().cpu(), output.detach().cpu())
+
+    torch.mean(output).backward()
+    torch.mean(output_colo).backward()
+    assert allclose(ln_op.weight.grad.cpu(), ln_op_colo.weight.torch_tensor().grad.cpu())
 def test_linear():
@@ -50,8 +75,8 @@ def test_element_wise():
     t_ref = torch.randn(3, 5)
     t = ColoTensor.init_from_torch_tensor(t_ref.clone())
     assert torch.mean(t) == torch.mean(t_ref)
-    assert allclose(torch.nn.functional.gelu(t), torch.nn.functional.gelu(t_ref))
-    assert allclose(torch.nn.functional.relu(t), torch.nn.functional.relu(t_ref))
+    assert allclose(torch.nn.functional.gelu(t).torch_tensor(), torch.nn.functional.gelu(t_ref))
+    assert allclose(torch.nn.functional.relu(t).torch_tensor(), torch.nn.functional.relu(t_ref))

     # Test a function not wrapped by
@@ -76,4 +101,5 @@ def check_all():

 if __name__ == '__main__':
-    test_lazy_init_tensor()
+    # test_lazy_init_ptensor()
+    test_layernorm()