"examples/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "2340751bf6308f04b81aaf0a287a94cac2350a52"
Unverified Commit a0ad587c authored by flybird11111's avatar flybird11111 Committed by GitHub
Browse files

[shardformer] refactor embedding resize (#5603)



* [branch rebase] rebase main to Feature/resize_embedding (#5554)

* fix

* [release] update version (#5411)

* [hotfix] fix typo s/keywrods/keywords etc. (#5429)

* [devops] fix compatibility (#5444)

* [devops] fix compatibility

* [hotfix] update compatibility test on pr

* [devops] fix compatibility

* [devops] record duration during comp test

* [test] decrease test duration

* fix falcon

* [shardformer] fix gathering output when using tensor parallelism (#5431)

* fix

* padding vocab_size when using pipeline parallellism

padding vocab_size when using pipeline parallellism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* [doc] release Open-Sora 1.0 with model weights (#5468)

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] update open-sora demo (#5479)

* [doc] update open-sora demo

* [doc] update open-sora demo

* [doc] update open-sora demo

* [example] add grok-1 inference (#5485)

* [misc] add submodule

* remove submodule

* [example] support grok-1 tp inference

* [example] add grok-1 inference script

* [example] refactor code

* [example] add grok-1 readme

* [example] add test ci

* [example] update readme

---------
Co-authored-by: default avatarHongxin Liu <lhx0217@gmail.com>
Co-authored-by: default avatardigger yu <digger-yu@outlook.com>
Co-authored-by: default avatarbinmakeswell <binmakeswell@gmail.com>

* [CI] run pre-commit (#5577)

* fix

* [release] update version (#5411)

* [hotfix] fix typo s/keywrods/keywords etc. (#5429)

* [devops] fix compatibility (#5444)

* [devops] fix compatibility

* [hotfix] update compatibility test on pr

* [devops] fix compatibility

* [devops] record duration during comp test

* [test] decrease test duration

* fix falcon

* [shardformer] fix gathering output when using tensor parallelism (#5431)

* fix

* padding vocab_size when using pipeline parallellism

padding vocab_size when using pipeline parallellism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* [doc] release Open-Sora 1.0 with model weights (#5468)

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] update open-sora demo (#5479)

* [doc] update open-sora demo

* [doc] update open-sora demo

* [doc] update open-sora demo

* [example] add grok-1 inference (#5485)

* [misc] add submodule

* remove submodule

* [example] support grok-1 tp inference

* [example] add grok-1 inference script

* [example] refactor code

* [example] add grok-1 readme

* [example] add test ci

* [example] update readme

* run pre-commit

---------
Co-authored-by: default avatarHongxin Liu <lhx0217@gmail.com>
Co-authored-by: default avatardigger yu <digger-yu@outlook.com>
Co-authored-by: default avatarbinmakeswell <binmakeswell@gmail.com>

* [rebase] rebase main to resize-embedding (#5581)

* [release] grok-1 314b inference (#5490)

* [release] grok-1 inference

* [release] grok-1 inference

* [release] grok-1 inference

* [example] update Grok-1 inference (#5495)

* revise grok-1 example

* remove unused arg in scripts

* prevent re-installing torch

* update readme

* revert modifying colossalai requirements

* add perf

* trivial

* add tokenizer url

* [hotfix] set return_outputs=False in examples and polish code (#5404)

* fix: simplify merge_batch

* fix: use return_outputs=False to eliminate extra memory consumption

* feat: add return_outputs warning

* style: remove `return_outputs=False` as it is the default value

* [release] grok-1 inference benchmark (#5500)

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [shardformer]Fix lm parallel. (#5480)

* fix

* padding vocab_size when using pipeline parallellism

padding vocab_size when using pipeline parallellism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* fix lm forward distribution

* fix

* test ci

* fix

* [fix] fix grok-1 example typo (#5506)

* [devops] fix example test ci (#5504)

* Fix ColoTensorSpec for py11 (#5440)

* fixed layout converter caching and updated tester

* Empty-Commit

* [shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests

* [format] applied code formatting on changed files in pull request 5510 (#5517)
Co-authored-by: default avatargithub-actions <github-actions@github.com>

* [shardformer] fix pipeline forward error if custom layer distribution is used (#5189)

* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution

* Change static methods for t5 layer distribution to member functions

* Change static methods for whisper layer distribution to member functions

* Replace whisper policy usage with self one

* Fix test case to use non-static layer distribution methods

* fix: fix typo

---------
Co-authored-by: default avatarWenhao Chen <cwher@outlook.com>

* [Fix] Grok-1 use tokenizer from the same pretrained path (#5532)

* [fix] use tokenizer from the same pretrained path

* trust remote code

* [ColossalChat] Update RLHF V2 (#5286)

* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2 nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporary

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------
Co-authored-by: default avatarTong Li <tong.li352711588@gmail.com>

* [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (#5508)

* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`

* feat: apply `GradientCheckpointConfig` to policy and llama_forward

* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager

* fix: add optional args for `distribute_layer` and `get_stage_index`

* fix: fix changed API calls

* test: update llama tests

* style: polish `GradientCheckpointConfig`

* fix: fix pipeline utils tests

* fix incorrect sharding without zero (#5545)
Co-authored-by: default avatarEdenzzzz <wtan45@wisc.edu>

* [shardformer] Sequence Parallelism Optimization (#5533)

* sequence parallel optimization

* validate sequence parallel in llama (code to be polished)

* shardformer api writing

* integrate sequence parallel in ShardFormer

* fix pp bugs and sp bugs for LlaMa model

* integrating ring-based sequence parallelism into ShardFormer

* [sequence parallelism]: Add fused megatron function

* integrating ring-based sequence parallelism into ShardFormer

---------
Co-authored-by: default avatarlinsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when using sp and flashattention together

* fix operation function name

* support flash attention for ulysses-style sp

* clarify sp process group

* fix compatibility bugs in moe plugin

* fix fused linear bugs

* fix linear layer test

* support gpt model all-to-all sp

* modify shard data dimension (meant to be dim=-1)

* support megatron-style sp and distributed attn for llama model

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* finish sp mode 3 support for gpt

* using all_to_all_single when batch size is 1

* support mode 2 sp in gpt2 (#5)

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* refactor ring implementation

* support mode 2 sp in gpt2

* polish code

* enable distributed attn mask when using sp mode 2 and 3 in llama

* automatically enable flash attn when using sp mode 2 and 3 in llama

* inplace attn mask

* add zero2 support for sequence parallel

* polish code

* fix bugs

* fix gemini checkpoint io

* loose tensor checking atol and rtol

* add comment

* fix llama layernorm grad

* fix zero grad

* fix zero grad

* fix conflict

* update split and gather auto grad func

* sequence parallel: inside text split (#6)

* polish code (part 1)

* polish code (part 2)

* polish code (part 2.5)

* polish code (part 3)

* sequence parallel: inside text split

* miscellaneous minor fixes

* polish code

* fix ulysses style ZeRO

* sequence parallel: inside text split

* miscellaneous minor fixes

* disaggregate sp group and dp group for  sp

* fix llama and gpt sp

* polish code

* move ulysses grad sync to ddp (#9)

* remove zero_stage and unbind the grad sync for alltoall sp

* add 2d group creation test

* move ulysses grad sync to ddp

* add 2d group creation test

* remove useless code

* change shard config not to enable sp when enable_all_optimizations

* add sp warnings for several model

* remove useless code

---------
Co-authored-by: default avatarlinsj20 <linsj20@mails.tsinghua.edu.cn>

* [hotfix] quick fixes to make legacy tutorials runnable (#5559)
Co-authored-by: default avatarEdenzzzz <wtan45@wisc.edu>

* [fix] fix typo s/muiti-node /multi-node etc. (#5448)

* [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548)

* [devops] remove post commit ci (#5566)

* [devops] remove post commit ci

* [misc] run pre-commit on all files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

---------
Co-authored-by: default avatarbinmakeswell <binmakeswell@gmail.com>
Co-authored-by: default avatarYuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: default avatarWenhao Chen <cwher@outlook.com>
Co-authored-by: default avatarHongxin Liu <lhx0217@gmail.com>
Co-authored-by: default avatarRocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: default avatarEdenzzzz <wtan45@wisc.edu>
Co-authored-by: default avatarEdenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: default avatargithub-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: default avatargithub-actions <github-actions@github.com>
Co-authored-by: default avatarInsu Jang <insujang@umich.edu>
Co-authored-by: default avatarYeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: default avatarTong Li <tong.li352711588@gmail.com>
Co-authored-by: default avatarZhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: default avatarlinsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: default avatardigger yu <digger-yu@outlook.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [shardformer]enable padding vocabulary size. (#5489)

* padding vocab_size when using pipeline parallellism

padding vocab_size when using pipeline parallellism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* padding vocab

* padding vocab

* fix

* fix

* fix

* test ci

* fix

fix

fix

fix

* fix

fix

* fix

* fix

* Update hybrid_parallel_plugin.py

fix

fix

fix

* fix

fix

* fix

fix

* fix

* resolve super init

resolve super init

resolve super init

resolve super init

* resolve comments

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* vocab checkpointio

* padding vocab_size when using pipeline parallellism

padding vocab_size when using pipeline parallellism

fix

fix

* fix

fix

fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* padding vocab

* fix

* fix

fix

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* cherry-pick

* revert moe modify

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix

fix

fix

fix

fix

fix

fix

fix

* resolve comments

resolve comments

resolve comments

resolve comments

resolve comments

* ptensor

ptensor

resolve comments

fix

fix

fix

fix

fix

resolve comments

resolve comments

resolve comments

resolve comments

resolve comments

---------
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarHongxin Liu <lhx0217@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix rebase

* fix rebase

---------
Co-authored-by: default avatarHongxin Liu <lhx0217@gmail.com>
Co-authored-by: default avatardigger yu <digger-yu@outlook.com>
Co-authored-by: default avatarbinmakeswell <binmakeswell@gmail.com>
Co-authored-by: default avatarYuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: default avatarWenhao Chen <cwher@outlook.com>
Co-authored-by: default avatarRocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: default avatarEdenzzzz <wtan45@wisc.edu>
Co-authored-by: default avatarEdenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: default avatargithub-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: default avatargithub-actions <github-actions@github.com>
Co-authored-by: default avatarInsu Jang <insujang@umich.edu>
Co-authored-by: default avatarYeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: default avatarTong Li <tong.li352711588@gmail.com>
Co-authored-by: default avatarZhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: default avatarlinsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3788fefc
...@@ -44,10 +44,10 @@ ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2 ...@@ -44,10 +44,10 @@ ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
def get_param_info(optim: Optimizer): def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes: # Get a backup of necessary information of parameters for future use, which includes:
# 1. A mapping from integer param_id to param32 shape. # 1. A mapping from integer param_id to param32 shape.
if optim is None: if optim is None:
return {} return {}
param_info = {"id2shape": {}} param_info = {"id2shape": {}}
start_index = 0 start_index = 0
for group in optim.param_groups: for group in optim.param_groups:
for param_id, param in enumerate(group["params"], start_index): for param_id, param in enumerate(group["params"], start_index):
...@@ -527,7 +527,7 @@ class GeminiPlugin(DPPluginBase): ...@@ -527,7 +527,7 @@ class GeminiPlugin(DPPluginBase):
dataloader: Optional[DataLoader] = None, dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None, lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]: ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
optimizer_params_info = get_param_info(optimizer) params_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper): if not isinstance(model, ModelWrapper):
# convert model to sync bn # convert model to sync bn
# FIXME(ver217): gemini does not support sync bn # FIXME(ver217): gemini does not support sync bn
...@@ -558,7 +558,7 @@ class GeminiPlugin(DPPluginBase): ...@@ -558,7 +558,7 @@ class GeminiPlugin(DPPluginBase):
**self.zero_optim_config, **self.zero_optim_config,
**self.optim_kwargs, **self.optim_kwargs,
tp_group=self.tp_group, tp_group=self.tp_group,
optimizer_params_info=optimizer_params_info, params_info=params_info,
verbose=self.verbose, verbose=self.verbose,
) )
......
...@@ -213,12 +213,7 @@ def get_param_info(optim: Optimizer): ...@@ -213,12 +213,7 @@ def get_param_info(optim: Optimizer):
if optim is None: if optim is None:
return {} return {}
param_info = { param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
"param_groups": [],
"param2id": {},
"id2param": {},
"param2shape": {},
}
start_index = 0 start_index = 0
for group in optim.param_groups: for group in optim.param_groups:
packed_group = {k: v for k, v in group.items() if k != "params"} packed_group = {k: v for k, v in group.items() if k != "params"}
...@@ -947,6 +942,8 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -947,6 +942,8 @@ class HybridParallelPlugin(PipelinePluginBase):
num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1. num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None. gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True. enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose a faster kernel. Defaults to 64.
""" """
def __init__( def __init__(
...@@ -989,6 +986,7 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -989,6 +986,7 @@ class HybridParallelPlugin(PipelinePluginBase):
num_model_chunks: int = 1, num_model_chunks: int = 1,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True, enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
) -> None: ) -> None:
super().__init__() super().__init__()
assert ( assert (
...@@ -1095,6 +1093,7 @@ class HybridParallelPlugin(PipelinePluginBase): ...@@ -1095,6 +1093,7 @@ class HybridParallelPlugin(PipelinePluginBase):
sequence_parallelism_mode=sequence_parallelism_mode, sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap, enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output, parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config, gradient_checkpoint_config=gradient_checkpoint_config,
) )
self.amp_config = dict( self.amp_config = dict(
......
...@@ -14,6 +14,12 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler ...@@ -14,6 +14,12 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.padded_tensor import (
init_as_padded_tensor,
is_padded_tensor,
to_padded_tensor,
to_unpadded_tensor,
)
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from .general_checkpoint_io import GeneralCheckpointIO from .general_checkpoint_io import GeneralCheckpointIO
...@@ -32,6 +38,7 @@ from .utils import ( ...@@ -32,6 +38,7 @@ from .utils import (
save_param_groups, save_param_groups,
save_state_dict, save_state_dict,
save_state_dict_shards, save_state_dict_shards,
search_padding_dim,
search_tp_partition_dim, search_tp_partition_dim,
sharded_optimizer_loading_epilogue, sharded_optimizer_loading_epilogue,
) )
...@@ -89,6 +96,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -89,6 +96,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if param is None: if param is None:
continue continue
# Gather tensor pieces when using tensor parallel. # Gather tensor pieces when using tensor parallel.
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
param_ = gather_distributed_param(param, keep_vars=False) param_ = gather_distributed_param(param, keep_vars=False)
block, block_size = state_dict_sharder.append_param(prefix + name, param_) block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None: if block is not None:
...@@ -231,7 +240,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -231,7 +240,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# When pipeline is used, each stage produces its own shard files and index files. # When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/ # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder. # After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file) final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files") tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
...@@ -251,6 +259,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -251,6 +259,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
use_pp_format=True, use_pp_format=True,
) )
if control_saving: if control_saving:
assert ( assert (
self.dp_rank == 0 and self.tp_rank == 0 self.dp_rank == 0 and self.tp_rank == 0
...@@ -867,6 +876,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -867,6 +876,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
dist.all_gather(gather_tensor, v, group=tp_group) dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim) v = torch.cat(gather_tensor, dim=partition_dim)
padding_dim = search_padding_dim(v.shape, original_shape)
if padding_dim is not None:
v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
v = to_unpadded_tensor(v)
state_[k] = v.detach().clone().to(device) state_[k] = v.detach().clone().to(device)
return state_ return state_
...@@ -899,6 +913,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ...@@ -899,6 +913,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if isinstance(v, torch.Tensor) and k != "step": if isinstance(v, torch.Tensor) and k != "step":
# Shard state along tensor parallel group. # Shard state along tensor parallel group.
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size) partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
global_shape = current_shape
if partition_dim is not None:
# pad embedding params
global_shape = (
*current_shape[:partition_dim],
current_shape[partition_dim] * self.tp_size,
*current_shape[partition_dim + 1 :],
)
padding_dim = search_padding_dim(global_shape, original_shape)
if padding_dim is not None:
v = to_padded_tensor(v, global_shape[padding_dim], padding_dim)
if partition_dim is not None: if partition_dim is not None:
slice_size = current_shape[partition_dim] slice_size = current_shape[partition_dim]
v = v.split(slice_size, dim=partition_dim)[self.tp_rank] v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
......
...@@ -120,6 +120,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz ...@@ -120,6 +120,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
return partition_dim return partition_dim
def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]:
padding_dim = None
for dim, length in enumerate(global_shape):
if length > original_shape[dim]:
padding_dim = dim
break
return padding_dim
# ====================================== # ======================================
# Helper classes and functions for saving shard file # Helper classes and functions for saving shard file
# ====================================== # ======================================
......
from ._operation import all_to_all_comm from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention from .attn import AttnMaskType, ColoAttention
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule from .parallel_module import ParallelModule
...@@ -25,6 +25,9 @@ __all__ = [ ...@@ -25,6 +25,9 @@ __all__ = [
"FusedRMSNorm", "FusedRMSNorm",
"FusedLinear1D_Col", "FusedLinear1D_Col",
"ParallelModule", "ParallelModule",
"PaddingEmbedding",
"PaddingLMHead",
"VocabParallelLMHead1D",
"AttnMaskType", "AttnMaskType",
"ColoAttention", "ColoAttention",
"all_to_all_comm", "all_to_all_comm",
......
...@@ -21,10 +21,10 @@ from colossalai.tensor.d_tensor.api import ( ...@@ -21,10 +21,10 @@ from colossalai.tensor.d_tensor.api import (
) )
from ._operation import gather_forward_split_backward, reduce_forward from ._operation import gather_forward_split_backward, reduce_forward
from .parallel_module import ParallelModule from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset from .utils import create_randomizer_with_offset
__all__ = ["Embedding1D", "VocabParallelEmbedding1D"] __all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"]
class Embedding1D(ParallelModule): class Embedding1D(ParallelModule):
...@@ -161,7 +161,80 @@ class Embedding1D(ParallelModule): ...@@ -161,7 +161,80 @@ class Embedding1D(ParallelModule):
return output_parallel return output_parallel
class VocabParallelEmbedding1D(ParallelModule): class PaddingEmbedding(PaddingParallelModule):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
device: torch.device = None,
weight: Optional[nn.Parameter] = None,
make_vocab_size_divisible_by: int = 64,
*args,
**kwargs,
):
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.embed_args = args
self.embed_kwargs = kwargs
self.padding_idx = padding_idx
if num_embeddings % make_vocab_size_divisible_by != 0:
self.num_embeddings = (
num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by)
)
# create weight and bias
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
super().__init__(self.num_embeddings, num_embeddings, weight)
if weight is None:
self.reset_parameters()
def reset_parameters(self) -> None:
init.normal_(self.weight)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input: Tensor) -> Tensor:
return F.embedding(input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
@staticmethod
def from_native_module(
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> PaddingParallelModule:
r"""
Convert a native pytorch embedding module to a parallel module.
"""
LazyInitContext.materialize(module)
# get the origin attributes
num_embeddings = module.num_embeddings
embedding_dim = module.embedding_dim
padding_idx = module.padding_idx
device = module.weight.device
# create the parallel module
padding_embedding = PaddingEmbedding(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
padding_idx=padding_idx,
device=device,
weight=module.weight,
*args,
**kwargs,
)
return padding_embedding
class VocabParallelEmbedding1D(PaddingParallelModule):
r"""Embedding parallelized in the vocabulary dimension. r"""Embedding parallelized in the vocabulary dimension.
Args: Args:
...@@ -201,10 +274,10 @@ class VocabParallelEmbedding1D(ParallelModule): ...@@ -201,10 +274,10 @@ class VocabParallelEmbedding1D(ParallelModule):
process_group: ProcessGroup = None, process_group: ProcessGroup = None,
weight: Optional[nn.Parameter] = None, weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(), weight_initializer: Callable = init.normal_(),
make_vocab_size_divisible_by: int = 64,
*args, *args,
**kwargs, **kwargs,
): ):
super().__init__()
self.num_embeddings = num_embeddings self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.embed_args = args self.embed_args = args
...@@ -214,8 +287,23 @@ class VocabParallelEmbedding1D(ParallelModule): ...@@ -214,8 +287,23 @@ class VocabParallelEmbedding1D(ParallelModule):
tensor_parallel_size = dist.get_world_size(group=process_group) tensor_parallel_size = dist.get_world_size(group=process_group)
tensor_parallel_rank = dist.get_rank(group=process_group) tensor_parallel_rank = dist.get_rank(group=process_group)
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size) # generate weight and bias
self.num_embeddings = self.num_embeddings_per_partition if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
# calculate new padding size
multiple = make_vocab_size_divisible_by * tensor_parallel_size
if num_embeddings % multiple != 0:
self.num_embeddings = num_embeddings + multiple - (num_embeddings % multiple)
# resize vocabulary size
super().__init__(self.num_embeddings, num_embeddings, weight)
# deal with tensor parallelism
self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size)
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
...@@ -226,13 +314,6 @@ class VocabParallelEmbedding1D(ParallelModule): ...@@ -226,13 +314,6 @@ class VocabParallelEmbedding1D(ParallelModule):
seed = torch.random.initial_seed() seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group) self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
# parameter
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if not is_distributed_tensor(self.weight): if not is_distributed_tensor(self.weight):
sharded_weight = shard_rowwise(self.weight.data, process_group) sharded_weight = shard_rowwise(self.weight.data, process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight) sharded_tensor_to_existing_param(sharded_weight, self.weight)
...@@ -243,7 +324,7 @@ class VocabParallelEmbedding1D(ParallelModule): ...@@ -243,7 +324,7 @@ class VocabParallelEmbedding1D(ParallelModule):
@staticmethod @staticmethod
def from_native_module( def from_native_module(
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule: ) -> PaddingParallelModule:
r""" r"""
Convert a native pytorch embedding module to a parallel module. Convert a native pytorch embedding module to a parallel module.
""" """
...@@ -303,11 +384,9 @@ class VocabParallelEmbedding1D(ParallelModule): ...@@ -303,11 +384,9 @@ class VocabParallelEmbedding1D(ParallelModule):
# Mask the input. # Mask the input.
masked_input = input_.clone() - self.vocab_start_index masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0 masked_input[input_mask] = 0
output_parallel = F.embedding( output_parallel = F.embedding(
masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
) )
# Mask the output embedding. # Mask the output embedding.
embedding_output = output_parallel.clone() embedding_output = output_parallel.clone()
embedding_output[input_mask, :] = 0.0 embedding_output[input_mask, :] = 0.0
......
...@@ -32,7 +32,7 @@ from ._operation import ( ...@@ -32,7 +32,7 @@ from ._operation import (
reducescatter_forward_gather_backward, reducescatter_forward_gather_backward,
split_forward_gather_backward, split_forward_gather_backward,
) )
from .parallel_module import ParallelModule from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset from .utils import create_randomizer_with_offset
__all__ = ["Linear1D_Col", "Linear1D_Row"] __all__ = ["Linear1D_Col", "Linear1D_Row"]
...@@ -84,8 +84,9 @@ class Linear1D_Col(ParallelModule): ...@@ -84,8 +84,9 @@ class Linear1D_Col(ParallelModule):
bias_: Optional[Parameter] = None, bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
**kwargs,
): ):
super().__init__() super().__init__(weight=weight, bias_=bias_, **kwargs)
# Keep input parameters # Keep input parameters
self.in_features = in_features self.in_features = in_features
...@@ -118,6 +119,7 @@ class Linear1D_Col(ParallelModule): ...@@ -118,6 +119,7 @@ class Linear1D_Col(ParallelModule):
else: else:
weight.data = weight.data.to(device=device, dtype=dtype) weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight self.weight = weight
if not is_distributed_tensor(self.weight): if not is_distributed_tensor(self.weight):
sharded_weight = shard_rowwise(self.weight.data, self.process_group) sharded_weight = shard_rowwise(self.weight.data, self.process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight) sharded_tensor_to_existing_param(sharded_weight, self.weight)
...@@ -140,7 +142,7 @@ class Linear1D_Col(ParallelModule): ...@@ -140,7 +142,7 @@ class Linear1D_Col(ParallelModule):
@staticmethod @staticmethod
def from_native_module( def from_native_module(
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
) -> ParallelModule: ) -> ParallelModule:
r""" r"""
Convert a native PyTorch linear layer to a parallelized linear layer. Convert a native PyTorch linear layer to a parallelized linear layer.
...@@ -173,7 +175,6 @@ class Linear1D_Col(ParallelModule): ...@@ -173,7 +175,6 @@ class Linear1D_Col(ParallelModule):
process_group=process_group, process_group=process_group,
weight=module.weight, weight=module.weight,
bias_=module.bias, bias_=module.bias,
*args,
**kwargs, **kwargs,
) )
...@@ -322,7 +323,7 @@ class Linear1D_Row(ParallelModule): ...@@ -322,7 +323,7 @@ class Linear1D_Row(ParallelModule):
@staticmethod @staticmethod
def from_native_module( def from_native_module(
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
) -> ParallelModule: ) -> ParallelModule:
r""" r"""
Convert a native PyTorch linear layer to a parallelized linear layer. Convert a native PyTorch linear layer to a parallelized linear layer.
...@@ -356,7 +357,6 @@ class Linear1D_Row(ParallelModule): ...@@ -356,7 +357,6 @@ class Linear1D_Row(ParallelModule):
process_group=process_group, process_group=process_group,
weight=module.weight, weight=module.weight,
bias_=module.bias, bias_=module.bias,
*args,
**kwargs, **kwargs,
) )
...@@ -439,3 +439,211 @@ class Linear1D_Row(ParallelModule): ...@@ -439,3 +439,211 @@ class Linear1D_Row(ParallelModule):
return output return output
else: else:
return output, self.bias return output, self.bias
class PaddingLMHead(PaddingParallelModule):
    """LM head linear layer whose output (vocabulary) dimension is padded up
    to a multiple of ``make_vocab_size_divisible_by``.

    ``forward`` slices the padded logits back to the original vocabulary size,
    so callers observe unchanged output shapes.

    Args:
        in_features (int): size of each input sample.
        out_features (int): original vocabulary size.
        bias (bool): whether to learn an additive bias, defaults to ``True``.
        dtype (torch.dtype, optional): dtype of parameters.
        device (torch.device, optional): device of parameters.
        weight (Parameter, optional): pre-existing weight to reuse.
        bias_ (Parameter, optional): pre-existing bias to reuse.
        make_vocab_size_divisible_by (int): padding multiple, defaults to 64.
        weight_initializer (Callable): weight initializer for fresh weights.
        bias_initializer (Callable): bias initializer for fresh biases.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = None,
        device: torch.device = None,
        weight: Optional[Parameter] = None,
        bias_: Optional[Parameter] = None,
        make_vocab_size_divisible_by: int = 64,
        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
    ):
        # Keep input parameters
        self.in_features = in_features
        self.out_features = out_features
        # Round the output size up to the next multiple.
        if out_features % make_vocab_size_divisible_by != 0:
            self.out_features = (
                out_features + make_vocab_size_divisible_by - (out_features % make_vocab_size_divisible_by)
            )

        # Bug fix: factory_kwargs was previously defined only when
        # `weight is None`, raising NameError when a weight was supplied but a
        # fresh bias still had to be created.
        factory_kwargs = {"device": device, "dtype": dtype}
        weight_was_given = weight is not None
        if not weight_was_given:
            weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs))
        else:
            weight.data = weight.data.to(device=device, dtype=dtype)

        if bias:
            if bias_ is None:
                # Bug fix: the original assigned the fresh Parameter to
                # `self.bias` before nn.Module.__init__ ran (which raises
                # AttributeError) and then discarded it, since `bias_` stayed
                # None and super().__init__ received None.
                bias_ = Parameter(torch.empty(out_features, **factory_kwargs))
            else:
                bias_.data = bias_.data.to(device=device, dtype=dtype)
        else:
            bias_ = None

        # resize embeddings (parent pads weight/bias rows up to self.out_features)
        super().__init__(self.out_features, out_features, weight, bias_)

        # Bug fix: the original guard `if weight is None` was always False
        # here, so fresh parameters were never initialized.
        if not weight_was_given:
            self.reset_parameters(weight_initializer, bias_initializer)

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        """Initialize weight and (if present) bias with the supplied initializers."""
        fan_in, fan_out = self.in_features, self.out_features
        weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
        if self.bias is not None:
            bias_initializer(self.bias, fan_in=fan_in)

    @staticmethod
    def from_native_module(
        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
    ) -> PaddingParallelModule:
        r"""
        Convert a native PyTorch linear layer to a padded LM head layer.
        """
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features
        out_features = module.out_features
        bias = module.bias is not None
        device = module.weight.device
        # build the padded module, reusing the original parameters
        lm_head_linear = PaddingLMHead(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            weight=module.weight,
            bias_=module.bias,
            **kwargs,
        )
        return lm_head_linear

    def forward(self, input: Tensor) -> Tensor:
        """Compute logits and strip the padding columns before returning."""
        output = F.linear(input, self.weight, self.bias)
        # old_num_embeddings is the pre-padding vocabulary size (set by the parent).
        output = output[..., : self.old_num_embeddings]
        return output
class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
    r"""LM head linear layer with column parallelism over the vocabulary
    dimension, padded so each rank holds an equal shard.

    The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
    its second dimension as :math:`A = [A_1, ..., A_p]`.

    Args:
        in_features (int): size of each input sample.
        out_features (int): size of each output sample (original vocab size).
        bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
        dtype (`torch.dtype`): The dtype of parameters, defaults to None.
        device (`torch.device`): The device of parameters, defaults to None.
        process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
        weight (Parameter, optional): pre-existing weight to reuse.
        bias_ (Parameter, optional): pre-existing bias to reuse.
        make_vocab_size_divisible_by (int): the padded vocab size is a multiple
            of this value times the tensor-parallel world size, defaults to 64.

    Remaining keyword arguments (e.g. ``gather_output``, ``skip_bias_add``,
    initializers) are forwarded to :class:`Linear1D_Col`. More details about
    ``initializer`` please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype: torch.dtype = None,
        device: torch.device = None,
        process_group: ProcessGroup = None,
        weight: Optional[Parameter] = None,
        bias_: Optional[Parameter] = None,
        make_vocab_size_divisible_by: int = 64,
        **kwargs,
    ):
        # create weight and bias
        # Bug fix: factory_kwargs must exist even when a weight is supplied,
        # because a fresh bias may still be created below.
        factory_kwargs = {"device": device, "dtype": dtype}
        if weight is None:
            # Bug fix: the original read `self.in_features` here, but that
            # attribute is only set later by Linear1D_Col.__init__ (via
            # super()), so creating a fresh weight raised AttributeError.
            weight = Parameter(torch.empty(out_features, in_features, **factory_kwargs))
        if bias:
            if bias_ is None:
                bias_ = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            bias_ = None

        # calculate new vocab size: pad so every tensor-parallel rank holds an
        # equally sized, make_vocab_size_divisible_by-aligned shard
        self.tensor_parallel_size = dist.get_world_size(group=process_group)
        new_out_features = out_features
        multiple = make_vocab_size_divisible_by * self.tensor_parallel_size
        if out_features % multiple != 0:
            new_out_features = out_features + multiple - (out_features % multiple)

        super().__init__(
            in_features=in_features,
            out_features=new_out_features,
            bias=bias,
            device=device,
            process_group=process_group,
            weight=weight,
            bias_=bias_,
            **kwargs,
            new_num_embeddings=new_out_features,
            old_num_embeddings=out_features,
        )
        # get the length of valid embeddings on this local shard: ranks past the
        # original vocabulary end hold only padding rows, which forward() trims.
        tp_rank = dist.get_rank(process_group)
        partition_size = self.new_num_embeddings // dist.get_world_size(process_group)
        if self.old_num_embeddings >= (tp_rank + 1) * partition_size:
            self.num_valid_embeddings_local = partition_size
        elif self.old_num_embeddings >= tp_rank * partition_size:
            self.num_valid_embeddings_local = self.old_num_embeddings - tp_rank * partition_size
        else:
            self.num_valid_embeddings_local = 0

    @staticmethod
    def from_native_module(
        module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
    ) -> PaddingParallelModule:
        r"""
        Convert a native PyTorch linear layer to a vocab-parallel LM head.
        """
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features
        out_features = module.out_features
        bias = module.bias is not None
        device = module.weight.device
        lm_head_linear = VocabParallelLMHead1D(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            process_group=process_group,
            weight=module.weight,
            bias_=module.bias,
            **kwargs,
        )
        return lm_head_linear

    def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
        """Run the column-parallel linear and strip padding from the output.

        Returns ``(output, bias)`` when ``skip_bias_add`` is set (mirroring
        Linear1D_Col), otherwise just ``output``.
        """
        # get forward output from the parent column-parallel linear
        if self.skip_bias_add:
            output, bias = super().forward(input_)
        else:
            output = super().forward(input_)

        # delete the padding of output: gathered outputs trim to the full
        # original vocab; sharded outputs trim to this rank's valid slice
        if self.gather_output:
            output = output[..., : self.old_num_embeddings]
        else:
            output = output[..., : self.num_valid_embeddings_local]

        # return
        if self.skip_bias_add:
            return output, bias
        return output
...@@ -15,7 +15,14 @@ class DistCrossEntropy(Function): ...@@ -15,7 +15,14 @@ class DistCrossEntropy(Function):
""" """
@staticmethod @staticmethod
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup): def forward(
ctx,
vocab_logits: torch.Tensor,
target: torch.Tensor,
ignore_index: int,
process_group: ProcessGroup,
vocab_size: int,
):
r""" r"""
Calculate the cross entropy loss before gather, the origin loss function is as follows: Calculate the cross entropy loss before gather, the origin loss function is as follows:
loss = -log(exp(x[class])/sum(exp(x[i])) loss = -log(exp(x[class])/sum(exp(x[i]))
...@@ -41,15 +48,21 @@ class DistCrossEntropy(Function): ...@@ -41,15 +48,21 @@ class DistCrossEntropy(Function):
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1) vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
# mask the target in the local device # mask the target in the local device
partition_vocab_size = vocab_logits.size()[-1]
rank = dist.get_rank(group=process_group) rank = dist.get_rank(group=process_group)
world_size = dist.get_world_size(group=process_group) world_size = dist.get_world_size(group=process_group)
global_vocab_size = partition_vocab_size * world_size if vocab_size == None:
partition_vocab_size = vocab_logits.size()[-1]
global_vocab_size = partition_vocab_size * world_size
else:
global_vocab_size = vocab_size
partition_vocab_size = global_vocab_size // world_size
# [down, up) => false, other device and -100 => true # [down, up) => false, other device and -100 => true
delta = (global_vocab_size + world_size - 1) // world_size delta = (global_vocab_size + world_size - 1) // world_size
down_threshold = rank * delta down_threshold = rank * delta
up_threshold = down_threshold + delta up_threshold = down_threshold + delta
if up_threshold > global_vocab_size:
up_threshold = global_vocab_size
mask = (target < down_threshold) | (target >= up_threshold) mask = (target < down_threshold) | (target >= up_threshold)
masked_target = target.clone() - down_threshold masked_target = target.clone() - down_threshold
masked_target[mask] = 0 masked_target[mask] = 0
...@@ -57,7 +70,8 @@ class DistCrossEntropy(Function): ...@@ -57,7 +70,8 @@ class DistCrossEntropy(Function):
# reshape the logits and target # reshape the logits and target
# reshape the vocab_logits to [bath_size * seq_len, vocab_size] # reshape the vocab_logits to [bath_size * seq_len, vocab_size]
# reshape the labels to [bath_size * seq_len] # reshape the labels to [bath_size * seq_len]
logits_2d = vocab_logits.view(-1, partition_vocab_size) self_vocab_size = vocab_logits.size()[-1]
logits_2d = vocab_logits.view(-1, self_vocab_size)
masked_target_1d = masked_target.view(-1) masked_target_1d = masked_target.view(-1)
# extract the x[class] and set the x[other device] to zero # extract the x[class] and set the x[other device] to zero
...@@ -104,10 +118,14 @@ class DistCrossEntropy(Function): ...@@ -104,10 +118,14 @@ class DistCrossEntropy(Function):
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
grad_logits.mul_(grad_output.unsqueeze(dim=-1)) grad_logits.mul_(grad_output.unsqueeze(dim=-1))
return grad_logits, None, None, None return grad_logits, None, None, None, None
def cross_entropy_1d( def cross_entropy_1d(
vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None vocab_logits: torch.Tensor,
labels: torch.Tensor,
ignore_index: int = -100,
process_group: ProcessGroup = None,
vocab_size: int = None,
) -> torch.Tensor: ) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group) return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size)
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import itertools import itertools
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Union from typing import List, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -20,11 +20,15 @@ from colossalai.tensor.d_tensor import ( ...@@ -20,11 +20,15 @@ from colossalai.tensor.d_tensor import (
is_distributed_tensor, is_distributed_tensor,
sharded_tensor_to_param, sharded_tensor_to_param,
) )
from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
__all__ = ["ParallelModule"] __all__ = ["ParallelModule"]
class ParallelModule(nn.Module, ABC): class ParallelModule(nn.Module, ABC):
def __init__(self, **kwargs):
super().__init__()
@abstractmethod @abstractmethod
def from_native_module( def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
...@@ -54,7 +58,7 @@ class ParallelModule(nn.Module, ABC): ...@@ -54,7 +58,7 @@ class ParallelModule(nn.Module, ABC):
""" """
for name, param in self._parameters.items(): for name, param in self._parameters.items():
if param is not None: if param is not None:
destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars) destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars).data
for name, buf in self._buffers.items(): for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set: if buf is not None and name not in self._non_persistent_buffers_set:
...@@ -171,3 +175,187 @@ class ParallelModule(nn.Module, ABC): ...@@ -171,3 +175,187 @@ class ParallelModule(nn.Module, ABC):
input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state: if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key) unexpected_keys.append(key)
class PaddingParallelModule(ParallelModule):
    """Base class for parallel modules whose leading (vocabulary) dimension is
    padded from ``old_num_embeddings`` up to ``new_num_embeddings``.

    Padding/unpadding is transparent to checkpoints: ``_save_to_state_dict``
    strips the padding before saving, and ``_load_from_state_dict`` re-pads a
    loaded tensor before copying it into the (possibly distributed) parameter.
    """

    def __init__(
        self,
        new_num_embeddings: int,
        old_num_embeddings: int,
        weight: Optional[nn.Parameter],
        bias_: Optional[nn.Parameter] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # padded size (what the parameter actually stores)
        self.new_num_embeddings = new_num_embeddings
        # original, unpadded size (what checkpoints and callers see)
        self.old_num_embeddings = old_num_embeddings
        self.weight = weight
        self.bias = bias_

        # Pad dim 0 unless the tensor is already sharded (distributed) or
        # already at the padded size.
        if not (is_distributed_tensor(self.weight) or self.weight.shape[0] == self.new_num_embeddings):
            self.resize_embedding_weight()

        if self.bias is not None and not (
            is_distributed_tensor(self.bias) or self.bias.shape[0] == self.new_num_embeddings
        ):
            self.resize_embedding_bias()

    @abstractmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
    ) -> "PaddingParallelModule":
        """
        Convert a native PyTorch module to a parallelized module.

        Args:
            module (nn.Module): the module to be converted.
            process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
                If this is a list, the process group at the ith index of the list will correspond to the process group
                in the ith axis of the device mesh. Defaults to None, which means the global process group.
        """
        raise NotImplementedError

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        r"""Saves module state to `destination` dictionary, containing a state
        of the module, but not its descendants. This is called on every
        submodule in :meth:`~torch.nn.Module.state_dict`.

        In rare cases, subclasses can achieve class-specific behavior by
        overriding this method with custom logic.

        Args:
            destination (dict): a dict where state will be stored
            prefix (str): the prefix for parameters and buffers used in this
                module
        """
        for name, param in self._parameters.items():
            if param is not None:
                # gather a sharded parameter to its full (padded) form first
                param = gather_distributed_param(param, keep_vars=keep_vars)
                # then strip the padding so the checkpoint sees the original size
                if is_padded_tensor(param):
                    param = to_unpadded_tensor(param)
                destination[prefix + name] = param.data

        for name, buf in self._buffers.items():
            if buf is not None and name not in self._non_persistent_buffers_set:
                destination[prefix + name] = buf if keep_vars else buf.detach()
        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
            destination[extra_state_key] = self.get_extra_state()

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        r"""Copies parameters and buffers from :attr:`state_dict` into only
        this module, but not its descendants. This is called on every submodule
        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
        For state dicts without metadata, :attr:`local_metadata` is empty.
        Subclasses can achieve class-specific backward compatible loading using
        the version number at `local_metadata.get("version", None)`.

        .. note::
            :attr:`state_dict` is not the same object as the input
            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
            it can be modified.

        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): a dict containing the metadata for this module.
                See
            strict (bool): whether to strictly enforce that the keys in
                :attr:`state_dict` with :attr:`prefix` match the names of
                parameters and buffers in this module
            missing_keys (list of str): if ``strict=True``, add missing keys to
                this list
            unexpected_keys (list of str): if ``strict=True``, add unexpected
                keys to this list
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`
        """
        for hook in self._load_state_dict_pre_hooks.values():
            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
        local_state = {k: v for k, v in local_name_params if v is not None}

        for name, param in local_state.items():
            key = prefix + name
            if key in state_dict:
                input_param = state_dict[key]
                if not torch.overrides.is_tensor_like(input_param):
                    error_msgs.append(
                        'While copying the parameter named "{}", '
                        "expected torch.Tensor or Tensor-like object from checkpoint but "
                        "received {}".format(key, type(input_param))
                    )
                    continue

                # Re-pad the (unpadded) checkpoint tensor to match the local
                # padded parameter before any shape check or copy.
                if is_padded_tensor(param):
                    input_param = to_padded_tensor(input_param, param._current_length, param._padding_dim)

                if is_distributed_tensor(param):
                    # shard the input param so it matches this rank's slice
                    device_mesh = get_device_mesh(param)
                    sharding_spec = get_sharding_spec(param)
                    sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
                    input_param = sharded_tensor_to_param(sharded_tensor)
                elif is_customized_distributed_tensor(param):
                    input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn)

                # This is used to avoid copying uninitialized parameters into
                # non-lazy modules, since they dont have the hook to do the checks
                # in such case, it will error when accessing the .shape attribute.
                is_param_lazy = torch.nn.parameter.is_lazy(param)
                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
                    input_param = input_param[0]

                if not is_param_lazy and input_param.shape != param.shape:
                    # local shape should match the one in checkpoint
                    error_msgs.append(
                        "size mismatch for {}: copying a param with shape {} from checkpoint, "
                        "the shape in current model is {}.".format(key, input_param.shape, param.shape)
                    )
                    continue

                try:
                    with torch.no_grad():
                        param.copy_(input_param)
                except Exception as ex:
                    error_msgs.append(
                        'While copying the parameter named "{}", '
                        "whose dimensions in the model are {} and "
                        "whose dimensions in the checkpoint are {}, "
                        "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
                    )
            elif strict:
                missing_keys.append(key)

        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
        if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
            if extra_state_key in state_dict:
                self.set_extra_state(state_dict[extra_state_key])
            elif strict:
                missing_keys.append(extra_state_key)
        elif strict and (extra_state_key in state_dict):
            unexpected_keys.append(extra_state_key)

        if strict:
            for key in state_dict.keys():
                if key.startswith(prefix) and key != extra_state_key:
                    input_name = key[len(prefix) :]
                    input_name = input_name.split(".", 1)[0]  # get the name of param/buffer/child
                    if input_name not in self._modules and input_name not in local_state:
                        unexpected_keys.append(key)

    def resize_embedding_weight(self):
        # Pad dim 0 of the weight up to new_num_embeddings.
        self.weight = to_padded_tensor(self.weight, self.new_num_embeddings, 0)

    def resize_embedding_bias(self):
        # Pad dim 0 of the bias up to new_num_embeddings.
        self.bias = to_padded_tensor(self.bias, self.new_num_embeddings, 0)
...@@ -26,7 +26,6 @@ from colossalai.shardformer.layer._operation import gather_forward_split_backwar ...@@ -26,7 +26,6 @@ from colossalai.shardformer.layer._operation import gather_forward_split_backwar
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d from ..layer import cross_entropy_1d
from ..layer._operation import gather_forward_split_backward
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -397,13 +396,11 @@ class GPT2PipelineForwards: ...@@ -397,13 +396,11 @@ class GPT2PipelineForwards:
shift_logits, shift_logits,
shift_labels, shift_labels,
process_group=shard_config.tensor_parallel_process_group, process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
) )
else: else:
loss = loss_fct(shift_logits, shift_labels) loss = loss_fct(shift_logits, shift_labels)
if not shard_config.parallel_output:
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
if not return_dict: if not return_dict:
output = (lm_logits,) + outputs[1:] output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
...@@ -1301,12 +1298,12 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): ...@@ -1301,12 +1298,12 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_logits = shift_logits.view(-1, shift_logits.size(-1)) shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1) shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d( loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
) )
if not shard_config.parallel_output:
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
if not return_dict: if not return_dict:
output = (lm_logits,) + transformer_outputs[1:] output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output return ((loss,) + output) if loss is not None else output
......
...@@ -316,7 +316,10 @@ class LlamaPipelineForwards: ...@@ -316,7 +316,10 @@ class LlamaPipelineForwards:
new_vocab_size = logits.shape[-1] new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size) shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d( loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
) )
else: else:
shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_logits = shift_logits.view(-1, self.config.vocab_size)
...@@ -735,11 +738,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): ...@@ -735,11 +738,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_labels = shift_labels.view(-1) shift_labels = shift_labels.view(-1)
# Enable model parallelism # Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device) shift_labels = shift_labels.to(shift_logits.device)
new_vocab_size = logits.shape[-1] new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size) shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d( loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
) )
if not return_dict: if not return_dict:
......
...@@ -195,3 +195,12 @@ class Policy(ABC): ...@@ -195,3 +195,12 @@ class Policy(ABC):
List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}] List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
""" """
return [] return []
def tie_weight_check(self):
input_embedding = self.model.get_input_embeddings()
output_embedding = self.model.get_output_embeddings()
return (
input_embedding is not None
and output_embedding is not None
and id(input_embedding.weight) == id(output_embedding.weight)
)
...@@ -37,17 +37,7 @@ class BertPolicy(Policy): ...@@ -37,17 +37,7 @@ class BertPolicy(Policy):
pass pass
def preprocess(self): def preprocess(self):
# reshape the embedding layer self.tie_weight = self.tie_weight_check()
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
# TODO:
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model return self.model
def module_policy(self): def module_policy(self):
...@@ -62,6 +52,13 @@ class BertPolicy(Policy): ...@@ -62,6 +52,13 @@ class BertPolicy(Policy):
policy = {} policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_fused_normalization: if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm norm_cls = col_nn.FusedLayerNorm
else: else:
...@@ -150,10 +147,6 @@ class BertPolicy(Policy): ...@@ -150,10 +147,6 @@ class BertPolicy(Policy):
policy[BertEmbeddings] = ModulePolicyDescription( policy[BertEmbeddings] = ModulePolicyDescription(
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="dropout", suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput, target_module=col_nn.DropoutForReplicatedInput,
...@@ -168,6 +161,18 @@ class BertPolicy(Policy): ...@@ -168,6 +161,18 @@ class BertPolicy(Policy):
target_key=BertModel, target_key=BertModel,
) )
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=embedding_cls,
)
],
policy=policy,
target_key=BertEmbeddings,
)
# optimization configuration # optimization configuration
# Handle bert layer # Handle bert layer
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
...@@ -237,8 +242,21 @@ class BertPolicy(Policy): ...@@ -237,8 +242,21 @@ class BertPolicy(Policy):
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription( description=SubModuleReplacementDescription(
suffix="decoder", suffix="decoder",
target_module=col_nn.Linear1D_Col, target_module=col_nn.VocabParallelLMHead1D,
kwargs={"gather_output": True}, kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},
),
policy=base_policy,
target_key=BertLMPredictionHead,
)
else:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="decoder",
target_module=col_nn.PaddingLMHead,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
), ),
policy=base_policy, policy=base_policy,
target_key=BertLMPredictionHead, target_key=BertLMPredictionHead,
......
...@@ -17,16 +17,7 @@ class BlipPolicy(Policy): ...@@ -17,16 +17,7 @@ class BlipPolicy(Policy):
pass pass
def preprocess(self): def preprocess(self):
# reshape the embedding layer self.tie_weight = self.tie_weight_check()
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
# TODO:
vocab_size = self.model.config.qformer_config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model return self.model
def module_policy(self): def module_policy(self):
...@@ -43,6 +34,13 @@ class BlipPolicy(Policy): ...@@ -43,6 +34,13 @@ class BlipPolicy(Policy):
policy = {} policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_fused_normalization: if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm norm_cls = col_nn.FusedLayerNorm
else: else:
...@@ -202,22 +200,48 @@ class BlipPolicy(Policy): ...@@ -202,22 +200,48 @@ class BlipPolicy(Policy):
], ],
) )
policy[OPTForCausalLM] = ModulePolicyDescription( policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
sub_module_replacement=[
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="model.decoder.embed_tokens", suffix="model.decoder.embed_tokens",
target_module=col_nn.VocabParallelEmbedding1D, target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
), ),
],
policy=policy,
target_key=OPTForCausalLM,
)
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="lm_head", suffix="lm_head",
target_module=col_nn.Linear1D_Col, target_module=col_nn.VocabParallelLMHead1D,
kwargs={"gather_output": True}, kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},
), ),
] ],
policy=policy,
target_key=OPTForCausalLM,
)
else:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.PaddingLMHead,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
],
policy=policy,
target_key=OPTForCausalLM,
) )
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
# optimization configuration # optimization configuration
# Handle Blip2EncoderLayer layer # Handle Blip2EncoderLayer layer
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
......
...@@ -35,16 +35,7 @@ class BloomPolicy(Policy): ...@@ -35,16 +35,7 @@ class BloomPolicy(Policy):
pass pass
def preprocess(self): def preprocess(self):
# reshape the embedding layer self.tie_weight = self.tie_weight_check()
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model return self.model
def module_policy(self): def module_policy(self):
...@@ -52,6 +43,13 @@ class BloomPolicy(Policy): ...@@ -52,6 +43,13 @@ class BloomPolicy(Policy):
policy = {} policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_fused_normalization: if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm norm_cls = col_nn.FusedLayerNorm
else: else:
...@@ -112,12 +110,19 @@ class BloomPolicy(Policy): ...@@ -112,12 +110,19 @@ class BloomPolicy(Policy):
method_replacement={ method_replacement={
"build_alibi_tensor": build_bloom_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) "build_alibi_tensor": build_bloom_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group)
}, },
sub_module_replacement=[ )
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="word_embeddings", suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D, target_module=embedding_cls,
) kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
], ],
policy=policy,
target_key=BloomModel,
) )
# optimization configuration # optimization configuration
...@@ -282,7 +287,21 @@ class BloomForCausalLMPolicy(BloomPolicy): ...@@ -282,7 +287,21 @@ class BloomForCausalLMPolicy(BloomPolicy):
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription( description=SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) suffix="lm_head",
target_module=col_nn.VocabParallelLMHead1D,
kwargs=dict(
gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
),
),
policy=policy,
target_key=BloomForCausalLM,
)
else:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.PaddingLMHead,
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
), ),
policy=policy, policy=policy,
target_key=BloomForCausalLM, target_key=BloomForCausalLM,
......
...@@ -25,20 +25,12 @@ class ChatGLMPolicy(Policy): ...@@ -25,20 +25,12 @@ class ChatGLMPolicy(Policy):
pass pass
def preprocess(self): def preprocess(self):
# Resize embedding
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.padded_vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
if self.pipeline_stage_manager is not None: if self.pipeline_stage_manager is not None:
# the batch_size_dim is bounded to Model # the batch_size_dim is bounded to Model
bsz_dim = 1 bsz_dim = 1
setattr(self.model, "batch_size_dim", bsz_dim) setattr(self.model, "batch_size_dim", bsz_dim)
self.tie_weight = self.tie_weight_check()
return self.model return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
...@@ -46,6 +38,13 @@ class ChatGLMPolicy(Policy): ...@@ -46,6 +38,13 @@ class ChatGLMPolicy(Policy):
policy = {} policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_fused_normalization: if self.shard_config.enable_fused_normalization:
if self.model.config.rmsnorm: if self.model.config.rmsnorm:
norm_cls = col_nn.FusedRMSNorm norm_cls = col_nn.FusedRMSNorm
...@@ -68,16 +67,6 @@ class ChatGLMPolicy(Policy): ...@@ -68,16 +67,6 @@ class ChatGLMPolicy(Policy):
sp_partial_derived = sp_mode == "split_gather" sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
policy[ChatGLMModel] = ModulePolicyDescription(
attribute_replacement={},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embedding.word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
)
],
)
policy[GLMBlock] = ModulePolicyDescription( policy[GLMBlock] = ModulePolicyDescription(
attribute_replacement={ attribute_replacement={
"self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads "self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
...@@ -114,6 +103,19 @@ class ChatGLMPolicy(Policy): ...@@ -114,6 +103,19 @@ class ChatGLMPolicy(Policy):
), ),
], ],
) )
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="embedding.word_embeddings",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
],
policy=policy,
target_key=ChatGLMModel,
)
# optimization configuration # optimization configuration
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=[ description=[
......
...@@ -32,16 +32,7 @@ class FalconPolicy(Policy): ...@@ -32,16 +32,7 @@ class FalconPolicy(Policy):
pass pass
def preprocess(self): def preprocess(self):
# reshape the embedding layer self.tie_weight = self.tie_weight_check()
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model return self.model
def module_policy(self): def module_policy(self):
...@@ -58,6 +49,14 @@ class FalconPolicy(Policy): ...@@ -58,6 +49,14 @@ class FalconPolicy(Policy):
warnings.warn("Falcon doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") warnings.warn("Falcon doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
policy = {} policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
attn_attribute_replacement = { attn_attribute_replacement = {
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
...@@ -98,12 +97,19 @@ class FalconPolicy(Policy): ...@@ -98,12 +97,19 @@ class FalconPolicy(Policy):
method_replacement={ method_replacement={
"build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group) "build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group)
}, },
sub_module_replacement=[ )
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="word_embeddings", suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D, target_module=embedding_cls,
) kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
], ],
policy=policy,
target_key=FalconModel,
) )
# optimization configuration # optimization configuration
...@@ -232,11 +238,26 @@ class FalconForCausalLMPolicy(FalconPolicy): ...@@ -232,11 +238,26 @@ class FalconForCausalLMPolicy(FalconPolicy):
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription( description=SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True) suffix="lm_head",
target_module=col_nn.VocabParallelLMHead1D,
kwargs=dict(
gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
),
), ),
policy=policy, policy=policy,
target_key=FalconForCausalLM, target_key=FalconForCausalLM,
) )
else:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.PaddingLMHead,
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
),
policy=policy,
target_key=FalconForCausalLM,
)
if self.pipeline_stage_manager: if self.pipeline_stage_manager:
self.set_pipeline_forward( self.set_pipeline_forward(
model_cls=FalconForCausalLM, model_cls=FalconForCausalLM,
......
...@@ -34,12 +34,7 @@ class GPT2Policy(Policy): ...@@ -34,12 +34,7 @@ class GPT2Policy(Policy):
r""" r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size Reshape the Embedding layer to make the embedding dimension divisible by world_size
""" """
if self.shard_config.enable_tensor_parallelism: self.tie_weight = self.tie_weight_check()
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model return self.model
def module_policy(self): def module_policy(self):
...@@ -47,6 +42,13 @@ class GPT2Policy(Policy): ...@@ -47,6 +42,13 @@ class GPT2Policy(Policy):
policy = {} policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_fused_normalization: if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm norm_cls = col_nn.FusedLayerNorm
else: else:
...@@ -73,10 +75,6 @@ class GPT2Policy(Policy): ...@@ -73,10 +75,6 @@ class GPT2Policy(Policy):
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
policy[GPT2Model] = ModulePolicyDescription( policy[GPT2Model] = ModulePolicyDescription(
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="drop", suffix="drop",
target_module=col_nn.DropoutForParallelInput, target_module=col_nn.DropoutForParallelInput,
...@@ -137,6 +135,17 @@ class GPT2Policy(Policy): ...@@ -137,6 +135,17 @@ class GPT2Policy(Policy):
), ),
], ],
) )
if embedding_cls is not None:
# padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="wte",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=GPT2Model,
)
# optimization configuration # optimization configuration
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
...@@ -298,8 +307,11 @@ class GPT2LMHeadModelPolicy(GPT2Policy): ...@@ -298,8 +307,11 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="lm_head", suffix="lm_head",
target_module=col_nn.Linear1D_Col, target_module=col_nn.VocabParallelLMHead1D,
kwargs={"gather_output": not self.shard_config.parallel_output}, kwargs={
"gather_output": False,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},
) )
], ],
) )
...@@ -308,7 +320,19 @@ class GPT2LMHeadModelPolicy(GPT2Policy): ...@@ -308,7 +320,19 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
addon_module[GPT2LMHeadModel].method_replacement = { addon_module[GPT2LMHeadModel].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
} }
module_policy.update(addon_module) else:
addon_module = {
GPT2LMHeadModel: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.PaddingLMHead,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
)
]
)
}
module_policy.update(addon_module)
if self.pipeline_stage_manager is not None: if self.pipeline_stage_manager is not None:
self.set_pipeline_forward( self.set_pipeline_forward(
...@@ -353,13 +377,28 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy): ...@@ -353,13 +377,28 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="lm_head", suffix="lm_head",
target_module=col_nn.Linear1D_Col, target_module=col_nn.VocabParallelLMHead1D,
kwargs={"gather_output": True}, kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},
) )
] ]
) )
} }
module_policy.update(addon_module) else:
addon_module = {
GPT2DoubleHeadsModel: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.PaddingLMHead,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
)
]
)
}
module_policy.update(addon_module)
if self.pipeline_stage_manager is not None: if self.pipeline_stage_manager is not None:
self.set_pipeline_forward( self.set_pipeline_forward(
......
...@@ -29,22 +29,21 @@ class GPTJPolicy(Policy): ...@@ -29,22 +29,21 @@ class GPTJPolicy(Policy):
pass pass
def preprocess(self): def preprocess(self):
# reshape the embedding layer self.tie_weight = self.tie_weight_check()
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model return self.model
def module_policy(self): def module_policy(self):
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel
policy = {} policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_sequence_parallelism: if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False self.shard_config.enable_sequence_parallelism = False
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.") warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
...@@ -54,10 +53,6 @@ class GPTJPolicy(Policy): ...@@ -54,10 +53,6 @@ class GPTJPolicy(Policy):
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
policy[GPTJModel] = ModulePolicyDescription( policy[GPTJModel] = ModulePolicyDescription(
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="drop", suffix="drop",
target_module=col_nn.DropoutForParallelInput, target_module=col_nn.DropoutForParallelInput,
...@@ -126,6 +121,17 @@ class GPTJPolicy(Policy): ...@@ -126,6 +121,17 @@ class GPTJPolicy(Policy):
], ],
) )
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="wte",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=GPTJModel,
)
# optimization configuration # optimization configuration
if self.shard_config.enable_fused_normalization: if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
...@@ -255,13 +261,28 @@ class GPTJForCausalLMPolicy(GPTJPolicy): ...@@ -255,13 +261,28 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="lm_head", suffix="lm_head",
target_module=col_nn.Linear1D_Col, target_module=col_nn.VocabParallelLMHead1D,
kwargs={"gather_output": True}, kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},
)
]
)
}
else:
addon_module = {
GPTJForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.PaddingLMHead,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
) )
] ]
) )
} }
policy.update(addon_module) policy.update(addon_module)
if self.pipeline_stage_manager is not None: if self.pipeline_stage_manager is not None:
self.set_pipeline_forward( self.set_pipeline_forward(
......
...@@ -6,7 +6,16 @@ import torch.nn as nn ...@@ -6,7 +6,16 @@ import torch.nn as nn
from torch import Tensor from torch import Tensor
from torch.nn import Module from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D from colossalai.shardformer.layer import (
FusedRMSNorm,
Linear1D_Col,
Linear1D_Row,
PaddingEmbedding,
PaddingLMHead,
RMSNorm,
VocabParallelEmbedding1D,
VocabParallelLMHead1D,
)
from ..modeling.llama import ( from ..modeling.llama import (
LlamaPipelineForwards, LlamaPipelineForwards,
...@@ -26,15 +35,7 @@ class LlamaPolicy(Policy): ...@@ -26,15 +35,7 @@ class LlamaPolicy(Policy):
pass pass
def preprocess(self): def preprocess(self):
if self.shard_config.enable_tensor_parallelism: self.tie_weight = self.tie_weight_check()
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
...@@ -42,6 +43,13 @@ class LlamaPolicy(Policy): ...@@ -42,6 +43,13 @@ class LlamaPolicy(Policy):
policy = {} policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding
if self.shard_config.enable_fused_normalization: if self.shard_config.enable_fused_normalization:
norm_cls = FusedRMSNorm norm_cls = FusedRMSNorm
else: else:
...@@ -167,10 +175,12 @@ class LlamaPolicy(Policy): ...@@ -167,10 +175,12 @@ class LlamaPolicy(Policy):
], ],
) )
if embedding_cls is not None:
self.append_or_create_submodule_replacement( self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription( description=SubModuleReplacementDescription(
suffix="embed_tokens", suffix="embed_tokens",
target_module=VocabParallelEmbedding1D, target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
), ),
policy=policy, policy=policy,
target_key=LlamaModel, target_key=LlamaModel,
...@@ -327,8 +337,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy): ...@@ -327,8 +337,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="lm_head", suffix="lm_head",
target_module=Linear1D_Col, target_module=VocabParallelLMHead1D,
kwargs={"gather_output": not self.shard_config.parallel_output}, kwargs={
"gather_output": not self.shard_config.parallel_output,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},
) )
], ],
) )
...@@ -337,7 +350,19 @@ class LlamaForCausalLMPolicy(LlamaPolicy): ...@@ -337,7 +350,19 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
new_item[LlamaForCausalLM].method_replacement = { new_item[LlamaForCausalLM].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config) "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
} }
policy.update(new_item) else:
new_item = {
LlamaForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=PaddingLMHead,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
)
],
)
}
policy.update(new_item)
if self.pipeline_stage_manager: if self.pipeline_stage_manager:
# set None as default # set None as default
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment