Unverified Commit b0b53a17 authored by digger yu, committed by GitHub

[nfc] fix typo colossalai/shardformer/ (#5133)

parent 451e9142
@@ -79,9 +79,9 @@ Following are the description `ShardConfig`'s arguments:
 - `enable_sequence_overlap`: Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when `enable_sequence_parallelism` is True. Defaults to False.
-- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalizaion`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
+- `enable_all_optimization`: Whether to turn on all optimization tools including `fused normalization`, `flash attention`, `JIT fused operators`, `sequence parallelism` and `sequence overlap`. Defaults to False.
-- `extra_kwargs`: A dict to store extra kwargs for ShardFomer.
+- `extra_kwargs`: A dict to store extra kwargs for ShardFormer.
 ### Write your own policy
...
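For orientation, here is a minimal sketch of how the flags documented in this hunk combine in practice. It is an assumption-laden example rather than repository code: it presumes `ShardConfig` is importable from `colossalai.shardformer` with the field names described above, and that `torch.distributed` is already initialized, since the config derives its tensor-parallel size from the process group.

```python
import torch.distributed as dist
from colossalai.shardformer import ShardConfig

# Assumes launch via torchrun or similar; ShardConfig reads the
# tensor-parallel size from the (default) process group.
dist.init_process_group(backend="nccl")

# Per the docs above, enable_sequence_overlap is only valid when
# enable_sequence_parallelism is also True.
shard_config = ShardConfig(
    enable_fused_normalization=True,
    enable_flash_attention=True,
    enable_sequence_parallelism=True,
    enable_sequence_overlap=True,
    extra_kwargs={},  # extra kwargs forwarded to ShardFormer
)
```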
@@ -32,7 +32,7 @@ def set_obj_list_element(obj, attr: str, value):
     r"""
     Set the element to value of a list object
-    It used like set_obj_list_element(obj, 'lyaers[0]', new_layer), it will set obj.layers[0] to value
+    It used like set_obj_list_element(obj, 'layers[0]', new_layer), it will set obj.layers[0] to value
     Args:
         obj (object): The object to set
...
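The fixed docstring describes attribute paths like `'layers[0]'`. As a toy illustration of that calling convention (a standalone sketch, not ColossalAI's actual implementation), a helper of this shape could look like:

```python
import re

def set_obj_list_element(obj, attr: str, value):
    # Parse an attribute spec such as "layers[0]" into a name and an index.
    match = re.fullmatch(r"(\w+)\[(\d+)\]", attr)
    name, index = match.group(1), int(match.group(2))
    # Replace the indexed element of the named list attribute in place.
    getattr(obj, name)[index] = value

class Model:
    def __init__(self):
        self.layers = ["embed", "attn", "mlp"]

m = Model()
set_obj_list_element(m, "layers[0]", "new_layer")
print(m.layers)  # ['new_layer', 'attn', 'mlp']
```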
@@ -22,8 +22,8 @@ class ShardConfig:
         enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
         enable_jit_fused (bool, optional): Whether to switch on JIT fused operators. Defaults to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, which partitions non-tensor-parallel regions along the sequence dimension. Defaults to False.
-        enable_sequence_overlap (bool): Whether to turn on sequence overlap, wheich overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
+        enable_sequence_overlap (bool): Whether to turn on sequence overlap, which overlap the computation and communication in sequence parallelism. It can only be used when enable_sequence_parallelism is True. Defaults to False.
-        enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalizaion', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
+        enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
     """
     tensor_parallel_process_group: Optional[ProcessGroup] = None
     pipeline_stage_manager: Optional[PipelineStageManager] = None
...
@@ -37,7 +37,7 @@ class ModelSharder(object):
         self.policy.set_model(self.model)
         self.policy.set_shard_config(self.shard_config)
         self._preprocess()
-        # get shared params before release unheld layers, this avoid misjudgement of shared params (None is None)
+        # get shared params before release unheld layers, this avoid misjudgment of shared params (None is None)
         shared_params = self.policy.get_shared_params()
        held_layers = self._release_unheld_layers()
        self._replace_module(include=held_layers)
...
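The corrected comment points at a subtle ordering issue: shared parameters must be collected before unheld layers are released, because identity checks performed afterwards would see `None is None` and report false sharing. A small standalone illustration of that failure mode (toy classes, not ColossalAI code):

```python
class Block:
    def __init__(self, weight):
        self.weight = weight

shared = object()                 # stands in for a shared parameter tensor
a, b = Block(shared), Block(shared)
c = Block(object())

print(a.weight is b.weight)       # True: genuinely shared
a.weight = c.weight = None        # simulate releasing layers this rank does not hold
print(a.weight is c.weight)       # also True: None is None, a false positive
```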