"examples/images/diffusion/vscode:/vscode.git/clone" did not exist on "b5dbb4617260875ea2835b01232acb8da5dcd2ca"
Unverified Commit 24c07687 authored by Jianghai, committed by GitHub

[shardformer] Pytree fix (#4533)

* pytree test

* test bert

* test bert

* test bert

* revise

* add register

* add register
parent 508ca36f
-from typing import Any, List, Optional
+from collections import OrderedDict
+from typing import Any, List, Optional, Tuple
 
 import torch
 import torch.cuda
 from torch.nn import Module
-from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
+from torch.utils._pytree import (
+    SUPPORTED_NODES,
+    LeafSpec,
+    TreeSpec,
+    _is_leaf,
+    _register_pytree_node,
+    tree_flatten,
+    tree_map,
+    tree_unflatten,
+)
+
+
+# this registration is for torch versions below 1.13.1 and may be removed in the future
+def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]:
+    return list(d.values()), list(d.keys())
+
+
+def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]':
+    return OrderedDict((key, value) for key, value in zip(context, values))
+
+
+_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
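Review note: once OrderedDict is registered as a pytree node, tree_flatten recurses into its values instead of returning the whole mapping as a single leaf, and tree_unflatten restores the original key order. A minimal sketch (not part of the diff):

import torch
from collections import OrderedDict
from torch.utils._pytree import tree_flatten, tree_unflatten

od = OrderedDict(a=torch.ones(2), b=torch.zeros(2))
leaves, spec = tree_flatten(od)          # leaves: [tensor([1., 1.]), tensor([0., 0.])]
restored = tree_unflatten(leaves, spec)  # an OrderedDict again, keys in original order
assert list(restored.keys()) == ['a', 'b']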
+
+
+def tree_map_hf(fn: Any, pytree: Any):
+    flat_args, spec = tree_flatten_hf(pytree)
+    return tree_unflatten([fn(i) for i in flat_args], spec)
+
+
+# use this flatten function to handle ModelingOutput class instances
+def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]:
+    """Flattens a pytree into a list of values and a TreeSpec that can be used
+    to reconstruct the pytree.
+    """
+    if isinstance(pytree, OrderedDict):
+        node_type = OrderedDict
+        flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
+        child_pytrees, context = flatten_fn(pytree)
+
+        # Recursively flatten the children
+        result: List[Any] = []
+        children_specs: List['TreeSpec'] = []
+        for child in child_pytrees:
+            flat, child_spec = tree_flatten_hf(child)
+            result += flat
+            children_specs.append(child_spec)
+        return result, TreeSpec(node_type, context, children_specs)
+    else:
+        result, tree_spec = tree_flatten(pytree)
+        return result, tree_spec
 
 
 def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
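Review note: the isinstance check above matters because HuggingFace ModelingOutput classes subclass OrderedDict, so they take the OrderedDict branch even though the subclass itself is never registered in SUPPORTED_NODES; on older torch/transformers stacks, plain tree_flatten would treat such an instance as a single leaf. A hedged usage sketch, assuming tree_map_hf above is in scope and with FakeModelOutput as a hypothetical stand-in for a transformers output class:

import torch
from collections import OrderedDict

class FakeModelOutput(OrderedDict):    # hypothetical stand-in for transformers.ModelOutput
    pass

out = FakeModelOutput(last_hidden_state=torch.randn(2, 4, requires_grad=True),
                      logits=torch.randn(2, 10))
detached = tree_map_hf(lambda t: t.detach(), out)    # rebuilt with the same keys
assert list(detached.keys()) == ['last_hidden_state', 'logits']
assert not detached['last_hidden_state'].requires_grad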
@@ -104,7 +154,7 @@ def detach(x: Any) -> Any:
     return x
 
 
-def merge_batch(data: List[Any]) -> Any:
+def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
     """Merge micro batches into a batch.
 
     Args:
@@ -118,15 +168,17 @@ def merge_batch(data: List[Any]) -> Any:
     flattened_data = []
     tree_spec = None
     for d in data:
-        elems, tree_spec = tree_flatten(d)
+        # elems should be an instance of OrderedDict
+        elems, tree_spec = tree_flatten_hf(d)
         flattened_data.append(elems)
     merged_data = []
     for elem_batch in zip(*flattened_data):
         if isinstance(elem_batch[0], torch.Tensor):
             if len(elem_batch[0].shape) == 0:    # set loss to None in pipeline outputs
                 merged_data.append(None)
             else:
-                merged_data.append(torch.cat(elem_batch, dim=0))
+                merged_data.append(torch.cat(elem_batch, dim=batch_size_dim))
         else:
             merged_data.append(list(elem_batch))
     return tree_unflatten(merged_data, tree_spec)
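Review note: a sketch of the new behavior (not from the commit, assuming merge_batch is in scope). Micro-batch outputs are flattened, concatenated leaf-by-leaf along batch_size_dim, and rebuilt into a single output; 0-dim tensors such as a scalar loss are replaced with None:

import torch
from collections import OrderedDict

mb0 = OrderedDict(logits=torch.randn(2, 10), loss=torch.tensor(0.5))
mb1 = OrderedDict(logits=torch.randn(2, 10), loss=torch.tensor(0.4))
merged = merge_batch([mb0, mb1], batch_size_dim=0)
assert merged['logits'].shape == (4, 10)    # micro-batches joined on the batch dim
assert merged['loss'] is None               # scalar losses are dropped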
@@ -6,12 +6,21 @@ import torch.cuda
 from torch.nn import Module
 from torch.utils._pytree import tree_map
 
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils.cuda import get_current_device
 
-from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
+from ._utils import (
+    detach,
+    get_batch_size,
+    get_micro_batch,
+    merge_batch,
+    model_forward,
+    retain_grad,
+    to_device,
+    tree_map_hf,
+)
 from .base import PipelineSchedule
@@ -154,7 +163,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             if accum_loss is not None:
                 accum_loss.add_(loss.detach())
             if outputs is not None:
-                outputs.append(tree_map(detach, output_obj))
+                outputs.append(tree_map_hf(detach, output_obj))
             return loss
         else:
             return output_obj
@@ -302,5 +311,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 self.send_backward(input_obj_grad)
 
         if outputs is not None:
-            outputs = merge_batch(outputs)
+            if isinstance(model, ModelWrapper):
+                model = model.unwrap()
+            outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0))
         return {'loss': accum_loss, 'outputs': outputs}
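Review note: the unwrap step is needed because batch_size_dim is set on the inner model by the policy, while the schedule may only see the wrapped model; getattr's default keeps dim 0 for models whose policy never sets the attribute. A sketch with a hypothetical wrapper class standing in for ModelWrapper:

import torch.nn as nn

class DummyWrapper(nn.Module):    # hypothetical stand-in for colossalai's ModelWrapper
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def unwrap(self) -> nn.Module:
        return self.module

inner = nn.Linear(4, 4)
inner.batch_size_dim = 1          # as a policy would do for a sequence-first model
model = DummyWrapper(inner)
if isinstance(model, DummyWrapper):
    model = model.unwrap()
assert getattr(model, 'batch_size_dim', 0) == 1    # unset models fall back to dim 0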
@@ -41,6 +41,11 @@ class ChatGLMPolicy(Policy):
             new_vocab_size = vocab_size + world_size - vocab_size % world_size
             self.model.resize_token_embeddings(new_vocab_size)
+
+        if self.pipeline_stage_manager is not None:
+            # the batch_size_dim is bound to the model
+            bsz_dim = 1
+            setattr(self.model, 'batch_size_dim', bsz_dim)
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
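Review note: bsz_dim is 1 here on the assumption that ChatGLM hidden states are laid out sequence-first as [seq_len, batch_size, hidden_size], so micro-batches must be concatenated along dim 1 rather than the default dim 0. A small sketch under that assumption:

import torch

micro = [torch.randn(8, 2, 64), torch.randn(8, 2, 64)]    # [seq_len, batch, hidden]
merged = torch.cat(micro, dim=1)    # dim=0 would wrongly grow the sequence length
assert merged.shape == (8, 4, 64)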
@@ -191,15 +191,10 @@ def check_output_hidden_state(org_output: Tensor,
     org_hidden_state = org_output.last_hidden_state
 
-    if stage_manager is None:
-        sharded_hidden_state = sharded_output.last_hidden_state
-
     if stage_manager and stage_manager.is_last_stage():
-        pipeline_output = sharded_output['outputs']
-        if isinstance(pipeline_output, List):
-            sharded_hidden_state = torch.cat([output.last_hidden_state for output in pipeline_output], dim=dim)
+        sharded_hidden_state = sharded_output['outputs']['last_hidden_state']
     else:
-        sharded_hidden_state = pipeline_output.last_hidden_state
+        sharded_hidden_state = sharded_output.last_hidden_state
 
     assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
         f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
@@ -179,6 +179,7 @@ def run_bert_3d_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
 
+    clear_layout_converter()