Unverified Commit 4df01949 authored by Ziyue Jiang's avatar Ziyue Jiang Committed by GitHub
Browse files

[Pipeline]Adapt to Pipelinable OPT (#1782)

parent 27de2523
...@@ -6,6 +6,7 @@ from colossalai.logging import get_dist_logger ...@@ -6,6 +6,7 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.layer.utils import CheckpointModule from colossalai.nn.layer.utils import CheckpointModule
from typing import List from typing import List
from collections import OrderedDict
def _binary_partition(weights: List, start: int, end: int): def _binary_partition(weights: List, start: int, end: int):
"""Returns the binary partition position of `weights`, given the start """Returns the binary partition position of `weights`, given the start
...@@ -159,8 +160,10 @@ def build_kwargs_for_module(function, input_tensor, kw_dict): ...@@ -159,8 +160,10 @@ def build_kwargs_for_module(function, input_tensor, kw_dict):
kwargs_offset = 0 kwargs_offset = 0
elif isinstance(input_tensor, torch.Tensor): elif isinstance(input_tensor, torch.Tensor):
kwargs_offset = 1 kwargs_offset = 1
else: elif isinstance(input_tensor, (tuple, OrderedDict)):
assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.' #assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
# Huggingface will take their own structures based on OrderedDict as the output
# between layers so we've to close this check.
kwargs_offset = len(input_tensor) kwargs_offset = len(input_tensor)
args_name_list = list(sig.parameters.keys()) args_name_list = list(sig.parameters.keys())
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]} kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment