Unverified Commit a2d32666 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[hotfix] make Gemini work for conv DNN (#1998)

parent 15589111
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor, ColoTensorSpec from colossalai.tensor import ColoTensor, ColoTensorSpec
from ._utils import GeneralTensor from colossalai.tensor.op_wrapper import colo_op_impl
from ._utils import GeneralTensor, convert_to_colo_tensor
def register_elementwise_op(op): def register_elementwise_op(op):
...@@ -15,16 +17,21 @@ def register_elementwise_op(op): ...@@ -15,16 +17,21 @@ def register_elementwise_op(op):
as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``. as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
This method computes on either a normal tensor or a sharded tensor. This method computes on either a normal tensor or a sharded tensor.
""" """
if 'inplace' in kwargs:
output = op(input_tensor, *args, **kwargs) # TODO(jiaruifang) inplace will cause bugs
if isinstance(input_tensor, ColoTensor): input_tensor = input_tensor.clone()
if isinstance(output, str): return op(input_tensor, *args, **kwargs)
return output else:
if not isinstance(output, torch.Tensor): output = op(input_tensor, *args, **kwargs)
raise NotImplementedError # return output
return ColoTensor.from_torch_tensor(output, if isinstance(input_tensor, ColoTensor):
spec=ColoTensorSpec(input_tensor.get_process_group(), if isinstance(output, str):
dist_attr=input_tensor.dist_spec)) return output
if not isinstance(output, torch.Tensor):
raise NotImplementedError
return ColoTensor.from_torch_tensor(output,
spec=ColoTensorSpec(input_tensor.get_process_group(),
dist_attr=input_tensor.dist_spec))
# Tensor op # Tensor op
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment