"python-wheel/examples/bls/bls.py" did not exist on "c9130f8f8ce264379131e9ee2973534fe4cbf713"
Commit b0b8ad28 authored by ver217's avatar ver217 Committed by Hongxin Liu
Browse files

[pipeline] update shardformer docstring

parent 59f6f573
from typing import Dict, List, Tuple
import torch.nn as nn import torch.nn as nn
from torch import Tensor
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
...@@ -24,7 +27,7 @@ class ShardFormer: ...@@ -24,7 +27,7 @@ class ShardFormer:
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased') org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
shard_config = ShardConfig() shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config) shard_former = ShardFormer(shard_config=shard_config)
model = shard_former.optimize(org_model) model, shared_params = shard_former.optimize(org_model)
``` ```
""" """
...@@ -32,7 +35,7 @@ class ShardFormer: ...@@ -32,7 +35,7 @@ class ShardFormer:
self.coordinator = DistCoordinator() self.coordinator = DistCoordinator()
self.shard_config = shard_config self.shard_config = shard_config
def optimize(self, model: nn.Module, policy: Policy = None): def optimize(self, model: nn.Module, policy: Policy = None) -> Tuple[nn.Module, List[Dict[int, Tensor]]]:
r""" r"""
This method will optimize the model based on the given policy. This method will optimize the model based on the given policy.
...@@ -40,6 +43,8 @@ class ShardFormer: ...@@ -40,6 +43,8 @@ class ShardFormer:
model (`torch.nn.Model`): the origin huggingface model model (`torch.nn.Model`): the origin huggingface model
shard_config (`ShardConfig`): the config for distribute information shard_config (`ShardConfig`): the config for distribute information
policy (`Policy`): the custom policy for sharding policy (`Policy`): the custom policy for sharding
Returns: the sharded model and the shared parameters
""" """
sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy) sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
shared_params = sharder.shard() shared_params = sharder.shard()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment