Unverified commit a0ad587c authored by flybird11111, committed by GitHub

[shardformer] refactor embedding resize (#5603)



* [branch rebase] rebase main to Feature/resize_embedding (#5554)

* fix

* [release] update version (#5411)

* [hotfix] fix typo s/keywrods/keywords etc. (#5429)

* [devops] fix compatibility (#5444)

* [devops] fix compatibility

* [hotfix] update compatibility test on pr

* [devops] fix compatibility

* [devops] record duration during comp test

* [test] decrease test duration

* fix falcon

* [shardformer] fix gathering output when using tensor parallelism (#5431)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* [doc] release Open-Sora 1.0 with model weights (#5468)

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] update open-sora demo (#5479)

* [doc] update open-sora demo

* [doc] update open-sora demo

* [doc] update open-sora demo

* [example] add grok-1 inference (#5485)

* [misc] add submodule

* remove submodule

* [example] support grok-1 tp inference

* [example] add grok-1 inference script

* [example] refactor code

* [example] add grok-1 readme

* [example] add test ci

* [example] update readme

---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* [CI] run pre-commit (#5577)

* fix

* [release] update version (#5411)

* [hotfix] fix typo s/keywrods/keywords etc. (#5429)

* [devops] fix compatibility (#5444)

* [devops] fix compatibility

* [hotfix] update compatibility test on pr

* [devops] fix compatibility

* [devops] record duration during comp test

* [test] decrease test duration

* fix falcon

* [shardformer] fix gathering output when using tensor parallelism (#5431)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* [doc] release Open-Sora 1.0 with model weights (#5468)

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] update open-sora demo (#5479)

* [doc] update open-sora demo

* [doc] update open-sora demo

* [doc] update open-sora demo

* [example] add grok-1 inference (#5485)

* [misc] add submodule

* remove submodule

* [example] support grok-1 tp inference

* [example] add grok-1 inference script

* [example] refactor code

* [example] add grok-1 readme

* [example] add test ci

* [example] update readme

* run pre-commit

---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* [rebase] rebase main to resize-embedding (#5581)

* [release] grok-1 314b inference (#5490)

* [release] grok-1 inference

* [release] grok-1 inference

* [release] grok-1 inference

* [example] update Grok-1 inference (#5495)

* revise grok-1 example

* remove unused arg in scripts

* prevent re-installing torch

* update readme

* revert modifying colossalai requirements

* add perf

* trivial

* add tokenizer url

* [hotfix] set return_outputs=False in examples and polish code (#5404)

* fix: simplify merge_batch

* fix: use return_outputs=False to eliminate extra memory consumption

* feat: add return_outputs warning

* style: remove `return_outputs=False` as it is the default value

* [release] grok-1 inference benchmark (#5500)

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [shardformer]Fix lm parallel. (#5480)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* fix lm forward distribution

* fix

* test ci

* fix

* [fix] fix grok-1 example typo (#5506)

* [devops] fix example test ci (#5504)

* Fix ColoTensorSpec for py11 (#5440)

* fixed layout converter caching and updated tester

* Empty-Commit

* [shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests

* [format] applied code formatting on changed files in pull request 5510 (#5517)
Co-authored-by: github-actions <github-actions@github.com>

* [shardformer] fix pipeline forward error if custom layer distribution is used (#5189)

* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution

* Change static methods for t5 layer distribution to member functions

* Change static methods for whisper layer distribution to member functions

* Replace whisper policy usage with self one

* Fix test case to use non-static layer distribution methods

* fix: fix typo

---------
Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [Fix] Grok-1 use tokenizer from the same pretrained path (#5532)

* [fix] use tokenizer from the same pretrained path

* trust remote code

* [ColossalChat] Update RLHF V2 (#5286)

* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporary

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------
Co-authored-by: Tong Li <tong.li352711588@gmail.com>

* [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogeneous shard policy for llama (#5508)

* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`

* feat: apply `GradientCheckpointConfig` to policy and llama_forward

* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager

* fix: add optional args for `distribute_layer` and `get_stage_index`

* fix: fix changed API calls

* test: update llama tests

* style: polish `GradientCheckpointConfig`

* fix: fix pipeline utils tests

* fix incorrect sharding without zero (#5545)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [shardformer] Sequence Parallelism Optimization (#5533)

* sequence parallel optimization

* validate sequence parallel in llama (code to be polished)

* shardformer api writing

* integrate sequence parallel in ShardFormer

* fix pp bugs and sp bugs for LlaMa model

* integrating ring-based sequence parallelism into ShardFormer

* [sequence parallelism]: Add fused megatron function

* integrating ring-based sequence parallelism into ShardFormer

---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when using sp and flash attention together

* fix operation function name

* support flash attention for ulysses-style sp

* clarify sp process group

* fix compatibility bugs in moe plugin

* fix fused linear bugs

* fix linear layer test

* support gpt model all-to-all sp

* modify shard data dimension (meant to be dim=-1)

* support megatron-style sp and distributed attn for llama model

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* finish sp mode 3 support for gpt

* using all_to_all_single when batch size is 1

* support mode 2 sp in gpt2 (#5)

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* refactor ring implementation

* support mode 2 sp in gpt2

* polish code

* enable distributed attn mask when using sp mode 2 and 3 in llama

* automatically enable flash attn when using sp mode 2 and 3 in llama

* inplace attn mask

* add zero2 support for sequence parallel

* polish code

* fix bugs

* fix gemini checkpoint io

* loose tensor checking atol and rtol

* add comment

* fix llama layernorm grad

* fix zero grad

* fix zero grad

* fix conflict

* update split and gather auto grad func

* sequence parallel: inside text split (#6)

* polish code (part 1)

* polish code (part 2)

* polish code (part 2.5)

* polish code (part 3)

* sequence parallel: inside text split

* miscellaneous minor fixes

* polish code

* fix ulysses style ZeRO

* sequence parallel: inside text split

* miscellaneous minor fixes

* disaggregate sp group and dp group for sp

* fix llama and gpt sp

* polish code

* move ulysses grad sync to ddp (#9)

* remove zero_stage and unbind the grad sync for alltoall sp

* add 2d group creation test

* move ulysses grad sync to ddp

* add 2d group creation test

* remove useless code

* change shard config not to enable sp when enable_all_optimizations

* add sp warnings for several models

* remove useless code

---------
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* [hotfix] quick fixes to make legacy tutorials runnable (#5559)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [fix] fix typo s/muiti-node /multi-node etc. (#5448)

* [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548)

* [devops] remove post commit ci (#5566)

* [devops] remove post commit ci

* [misc] run pre-commit on all files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

---------
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [shardformer]enable padding vocabulary size. (#5489)

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* padding vocab

* padding vocab

* fix

* fix

* fix

* test ci

* fix

fix

fix

fix

* fix

fix

* fix

* fix

* Update hybrid_parallel_plugin.py

fix

fix

fix

* fix

fix

* fix

fix

* fix

* resolve super init

resolve super init

resolve super init

resolve super init

* resolve comments

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* vocab checkpointio

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

fix

fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* padding vocab

* fix

* fix

fix

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* cherry-pick

* revert moe modify

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix

fix

fix

fix

fix

fix

fix

fix

* resolve comments

resolve comments

resolve comments

resolve comments

resolve comments

* ptensor

ptensor

resolve comments

fix

fix

fix

fix

fix

resolve comments

resolve comments

resolve comments

resolve comments

resolve comments

---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix rebase

* fix rebase

---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3788fefc
@@ -3,7 +3,15 @@ from typing import Dict, Union
 import torch.nn as nn

-from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import (
+    FusedRMSNorm,
+    Linear1D_Col,
+    Linear1D_Row,
+    PaddingEmbedding,
+    PaddingLMHead,
+    VocabParallelEmbedding1D,
+    VocabParallelLMHead1D,
+)

 from ..modeling.mistral import get_mistral_flash_attention_forward
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -16,15 +24,7 @@ class MistralPolicy(Policy):
         pass

     def preprocess(self):
-        if self.shard_config.enable_tensor_parallelism:
-            # Resize embedding
-            vocab_size = self.model.config.vocab_size
-            world_size = self.shard_config.tensor_parallel_size
-
-            if vocab_size % world_size != 0:
-                new_vocab_size = vocab_size + world_size - vocab_size % world_size
-                self.model.resize_token_embeddings(new_vocab_size)
+        self.tie_weight = self.tie_weight_check()
         return self.model

     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
@@ -32,6 +32,13 @@ class MistralPolicy(Policy):
         policy = {}

+        embedding_cls = None
+        if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
         if self.shard_config.enable_sequence_parallelism:
             self.shard_config.enable_sequence_parallelism = False
             warnings.warn(
@@ -80,10 +87,12 @@ class MistralPolicy(Policy):
                 ],
             )

+        if embedding_cls is not None:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="embed_tokens",
-                    target_module=VocabParallelEmbedding1D,
+                    target_module=embedding_cls,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
                 ),
                 policy=policy,
                 target_key=MistralModel,
@@ -146,6 +155,8 @@ class MistralForCausalLMPolicy(MistralPolicy):
         from transformers import MistralForCausalLM

         policy = super().module_policy()
+        if self.pipeline_stage_manager:
+            warnings.warn("Mistral doesn't support pipeline parallelism now.")

         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
@@ -153,16 +164,30 @@ class MistralForCausalLMPolicy(MistralPolicy):
                 MistralForCausalLM: ModulePolicyDescription(
                     sub_module_replacement=[
                         SubModuleReplacementDescription(
-                            suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                            suffix="lm_head",
+                            target_module=VocabParallelLMHead1D,
+                            kwargs=dict(
+                                gather_output=True,
+                                make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
+                            ),
+                        )
+                    ]
+                )
+            }
+        else:
+            new_item = {
+                MistralForCausalLM: ModulePolicyDescription(
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="lm_head",
+                            target_module=PaddingLMHead,
+                            kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
                         )
                     ]
                 )
             }
-            if self.pipeline_stage_manager:
-                warnings.warn("Mistral doesn't support pipeline parallelism now.")
-            policy.update(new_item)
+        policy.update(new_item)

         return policy
...
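For context, the preprocess logic removed in the hunk above rounded the vocabulary size up to the next multiple of the tensor-parallel world size and called resize_token_embeddings on the model; the replacement layers instead pad the weight inside the embedding/LM-head modules to a multiple of make_vocab_size_divisible_by (default 64, see the ShardConfig hunk below). A minimal standalone sketch of that rounding arithmetic; the helper name is illustrative, not a ColossalAI API:

def round_up_to_multiple(vocab_size: int, divisor: int) -> int:
    # round vocab_size up to the nearest multiple of divisor
    if vocab_size % divisor == 0:
        return vocab_size
    return vocab_size + divisor - vocab_size % divisor

# old behaviour: resize the token embeddings to a multiple of the TP world size
assert round_up_to_multiple(32001, 4) == 32004
# new behaviour (sketch): pad the weight to a multiple of make_vocab_size_divisible_by
assert round_up_to_multiple(32001, 64) == 32064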
@@ -5,7 +5,16 @@ from typing import Callable, Dict, List
 import torch.nn as nn
 from torch import Tensor, nn

-from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import (
+    FusedLayerNorm,
+    LayerNorm,
+    Linear1D_Col,
+    Linear1D_Row,
+    PaddingEmbedding,
+    PaddingLMHead,
+    VocabParallelEmbedding1D,
+    VocabParallelLMHead1D,
+)

 from .._utils import getattr_
 from ..modeling.jit import get_jit_fused_dropout_add_func
@@ -41,16 +50,7 @@ class OPTPolicy(Policy):
         pass

     def preprocess(self):
-        # reshape the embedding layer
-        r"""
-        Reshape the Embedding layer to make the embedding dimension divisible by world_size
-        """
-        if self.shard_config.enable_tensor_parallelism:
-            vocab_size = self.model.config.vocab_size
-            world_size = self.shard_config.tensor_parallel_size
-            if vocab_size % world_size != 0:
-                new_vocab_size = vocab_size + world_size - vocab_size % world_size
-                self.model.resize_token_embeddings(new_vocab_size)
+        self.tie_weight = self.tie_weight_check()
         return self.model

     def module_policy(self):
@@ -58,6 +58,13 @@ class OPTPolicy(Policy):
         policy = {}

+        embedding_cls = None
+        if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
         if self.shard_config.enable_fused_normalization:
             norm_cls = FusedLayerNorm
         else:
@@ -68,14 +75,6 @@ class OPTPolicy(Policy):
             warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

         if self.shard_config.enable_tensor_parallelism:
-            policy[OPTDecoder] = ModulePolicyDescription(
-                sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="embed_tokens",
-                        target_module=VocabParallelEmbedding1D,
-                    )
-                ]
-            )
             policy[OPTDecoderLayer] = ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
@@ -114,6 +113,17 @@ class OPTPolicy(Policy):
                 ],
             )

+        if embedding_cls is not None:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="embed_tokens",
+                    target_module=embedding_cls,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
+                policy=policy,
+                target_key=OPTDecoder,
+            )
+
         # optimization configuration
         self.append_or_create_submodule_replacement(
             description=SubModuleReplacementDescription(
@@ -253,8 +263,20 @@ class OPTForCausalLMPolicy(OPTPolicy):
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="lm_head",
-                    target_module=Linear1D_Col,
-                    kwargs=dict(gather_output=True),
+                    target_module=VocabParallelLMHead1D,
+                    kwargs=dict(
+                        gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
+                    ),
+                ),
+                policy=policy,
+                target_key=OPTForCausalLM,
+            )
+        else:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=PaddingLMHead,
+                    kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
                 ),
                 policy=policy,
                 target_key=OPTForCausalLM,
...
@@ -13,8 +13,11 @@ from colossalai.shardformer.layer import (
     FusedRMSNorm,
     Linear1D_Col,
     Linear1D_Row,
+    PaddingEmbedding,
+    PaddingLMHead,
     RMSNorm,
     VocabParallelEmbedding1D,
+    VocabParallelLMHead1D,
 )
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
@@ -36,16 +39,7 @@ class T5BasePolicy(Policy):
         pass

     def preprocess(self):
-        # reshape the embedding layer
-        r"""
-        Reshape the Embedding layer to make the embedding dimension divisible by world_size
-        """
-        if self.shard_config.enable_tensor_parallelism:
-            vocab_size = self.model.config.vocab_size
-            world_size = self.shard_config.tensor_parallel_size
-            if vocab_size % world_size != 0:
-                new_vocab_size = vocab_size + world_size - vocab_size % world_size
-                self.model.resize_token_embeddings(new_vocab_size)
+        self.tie_weight = self.tie_weight_check()
         return self.model

     def module_policy(self):
@@ -61,6 +55,13 @@ class T5BasePolicy(Policy):
         policy = {}

+        embedding_cls = None
+        if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
         if self.shard_config.enable_fused_normalization:
             norm_cls = FusedRMSNorm
         else:
@@ -77,10 +78,6 @@ class T5BasePolicy(Policy):
                     suffix="dropout",
                     target_module=DropoutForParallelInput,
                 ),
-                SubModuleReplacementDescription(
-                    suffix="embed_tokens",
-                    target_module=VocabParallelEmbedding1D,
-                ),
             ]
         )
         policy[T5LayerSelfAttention] = ModulePolicyDescription(
@@ -176,6 +173,17 @@ class T5BasePolicy(Policy):
                 ]
             )

+        if embedding_cls is not None:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="embed_tokens",
+                    target_module=embedding_cls,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
+                policy=policy,
+                target_key=T5Stack,
+            )
+
         # optimization configuration
         self.append_or_create_submodule_replacement(
             description=SubModuleReplacementDescription(
@@ -370,11 +378,19 @@ class T5ModelPolicy(T5BasePolicy):
         policy = super().module_policy()

+        embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
+        if embedding_cls is not None:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="shared",
-                    target_module=VocabParallelEmbedding1D,
+                    target_module=embedding_cls,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
                 ),
                 policy=policy,
                 target_key=T5Model,
@@ -406,17 +422,44 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
         policy = super().module_policy()

+        embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
+        if embedding_cls is not None:
             self.append_or_create_submodule_replacement(
-                description=[
-                    SubModuleReplacementDescription(
-                        suffix="shared",
-                        target_module=VocabParallelEmbedding1D,
-                    ),
-                    SubModuleReplacementDescription(
-                        suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
-                    ),
-                ],
+                description=SubModuleReplacementDescription(
+                    suffix="shared",
+                    target_module=embedding_cls,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
+                policy=policy,
+                target_key=T5ForConditionalGeneration,
+            )
+
+        if self.shard_config.enable_tensor_parallelism:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=VocabParallelLMHead1D,
+                    kwargs={
+                        "gather_output": True,
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                    },
+                ),
+                policy=policy,
+                target_key=T5ForConditionalGeneration,
+            )
+        else:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=PaddingLMHead,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
                 policy=policy,
                 target_key=T5ForConditionalGeneration,
             )
@@ -467,11 +510,19 @@ class T5EncoderPolicy(T5BasePolicy):
         policy = super().module_policy()

+        embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = PaddingEmbedding
+
+        if embedding_cls is not None:
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="shared",
-                    target_module=VocabParallelEmbedding1D,
+                    target_module=embedding_cls,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
                 ),
                 policy=policy,
                 target_key=T5EncoderModel,
...
@@ -45,11 +45,7 @@ class WhisperPolicy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
+        self.tie_weight = self.tie_weight_check()
         return self.model

     def module_policy(self):
@@ -63,6 +59,13 @@ class WhisperPolicy(Policy):
         policy = {}

+        embedding_cls = None
+        if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = col_nn.VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = col_nn.PaddingEmbedding
+
         if self.shard_config.enable_fused_normalization:
             norm_cls = col_nn.FusedLayerNorm
         else:
@@ -167,13 +170,17 @@ class WhisperPolicy(Policy):
                 ],
             )

-        policy[WhisperDecoder] = ModulePolicyDescription(
-            sub_module_replacement=[
+        if embedding_cls is not None:
+            self.append_or_create_submodule_replacement(
+                description=[
                     SubModuleReplacementDescription(
                         suffix="embed_tokens",
-                        target_module=col_nn.VocabParallelEmbedding1D,
+                        target_module=embedding_cls,
+                        kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
                     ),
-            ]
+                ],
+                policy=policy,
+                target_key=WhisperDecoder,
             )

         # optimization configuration
@@ -280,8 +287,21 @@ class WhisperPolicy(Policy):
             self.append_or_create_submodule_replacement(
                 description=SubModuleReplacementDescription(
                     suffix="proj_out",
-                    target_module=col_nn.Linear1D_Col,
-                    kwargs={"gather_output": True},
+                    target_module=col_nn.VocabParallelLMHead1D,
+                    kwargs={
+                        "gather_output": True,
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                    },
+                ),
+                policy=base_policy,
+                target_key=WhisperForConditionalGeneration,
+            )
+        else:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="proj_out",
+                    target_module=col_nn.PaddingLMHead,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
                 ),
                 policy=base_policy,
                 target_key=WhisperForConditionalGeneration,
@@ -526,9 +546,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):

 # WhisperForAudioClassification
 class WhisperForAudioClassificationPolicy(WhisperPolicy):
-    def preprocess(self):
-        return self.model
-
     def module_policy(self):
         from transformers import WhisperForAudioClassification
...
@@ -42,10 +42,9 @@ class ShardConfig:
     sequence_parallelism_mode: str = None
     enable_sequence_overlap: bool = False
     parallel_output: bool = True
+    make_vocab_size_divisible_by: int = 64
     gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
     extra_kwargs: Dict[str, Any] = field(default_factory=dict)
-    # TODO padding vocab
-    # make_vocab_size_divisible_by: int = 128
     # pipeline_parallel_size: int
     # data_parallel_size: int
     # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
...
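The new make_vocab_size_divisible_by field defaults to 64 and is forwarded to the embedding and LM-head replacements through their kwargs in the policy hunks above. A hedged configuration sketch, using only fields visible in this diff plus the standard enable_tensor_parallelism flag; everything else keeps its default:

from colossalai.shardformer import ShardConfig

# with tensor parallelism disabled, the policies above use PaddingLMHead
# (and PaddingEmbedding when the input/output embeddings are tied)
shard_config = ShardConfig(
    enable_tensor_parallelism=False,
    make_vocab_size_divisible_by=64,  # pad the vocabulary to a multiple of 64
)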
@@ -10,6 +10,7 @@ from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.tensor.d_tensor.comm_spec import *
 from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.misc import LayoutException
+from colossalai.tensor.padded_tensor.api import init_as_padded_tensor, is_padded_tensor
 from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator

 from .sharding_spec import ShardingSpec
@@ -607,8 +608,18 @@ class LayoutConverter(metaclass=SingletonMeta):
                 [3.],
                 [3.]])
         """
         _, comm_action_sequence = self.layout_converting(source_layout, target_layout)
+
+        target_tensor = tensor
         for comm_spec in comm_action_sequence:
-            tensor = comm_spec.covert_spec_to_action(tensor)
-        tensor.dist_layout = target_layout
-        return tensor
+            target_tensor = comm_spec.covert_spec_to_action(target_tensor)
+        target_tensor.dist_layout = target_layout
+
+        # restore the padding information
+        if is_padded_tensor(tensor) and not is_padded_tensor(target_tensor):
+            target_tensor = init_as_padded_tensor(
+                target_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
+            )
+        return target_tensor
from .api import init_as_padded_tensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
__all__ = ["is_padded_tensor", "to_padded_tensor", "to_unpadded_tensor", "init_as_padded_tensor"]
import torch
def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
"""
Hijack the detach and clone methods of the tensor so that the padding metadata is carried over to the new tensor.
Args:
ptensor (torch.Tensor): The tensor to be hijacked.
Returns:
torch.Tensor: The hijacked tensor.
"""
ptensor._unpad_detach = ptensor.detach
ptensor._unpad_clone = ptensor.clone
def new_detach(self):
t_ = self._unpad_detach()
t_._padding_dim = self._padding_dim
t_._origin_length = self._origin_length
t_._current_length = self._current_length
return t_
def new_clone(self, *args, **kwargs):
t_ = self._unpad_clone(*args, **kwargs)
t_._padding_dim = self._padding_dim
t_._origin_length = self._origin_length
t_._current_length = self._current_length
return t_
# bind the new methods to the tensor
ptensor.detach = new_detach.__get__(ptensor)
ptensor.clone = new_clone.__get__(ptensor)
return ptensor
def _hijack_back_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
"""
Restore the original detach and clone methods of the tensor.
Args:
ptensor (torch.Tensor): The tensor whose original methods should be restored.
Returns:
torch.Tensor: The tensor with its original detach and clone methods.
"""
ptensor.detach = ptensor._unpad_detach
ptensor.clone = ptensor._unpad_clone
delattr(ptensor, "_unpad_detach")
delattr(ptensor, "_unpad_clone")
return ptensor
def is_padded_tensor(tensor: torch.Tensor) -> bool:
"""
Check whether the given tensor is a padding tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
bool: Whether the given tensor is a padding tensor.
"""
return hasattr(tensor, "_padding_dim")
def to_padded_tensor(
tensor: torch.Tensor,
current_length: int,
padding_dim: int,
) -> torch.Tensor:
assert (
padding_dim < tensor.dim()
), f"Please passing a valid padding_dim. the dimension of the tensor is {tensor.dim()}"
if is_padded_tensor(tensor):
return tensor
origin_length = tensor.shape[padding_dim]
padding_num = current_length - origin_length
padding_data = torch.zeros(
*tensor.shape[:padding_dim],
padding_num,
*tensor.shape[padding_dim + 1 :],
device=tensor.device,
dtype=tensor.dtype,
)
tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous()
tensor._padding_dim = padding_dim
tensor._origin_length = origin_length
tensor._current_length = current_length
_hijack_detach_and_clone(tensor)
return tensor
def to_unpadded_tensor(ptensor: torch.Tensor):
if not is_padded_tensor(ptensor):
return ptensor
unpad_slices = [slice(None)] * ptensor.dim()
unpad_slices[ptensor._padding_dim] = slice(None, ptensor._origin_length)
ptensor.data = ptensor.data[tuple(unpad_slices)]
delattr(ptensor, "_padding_dim")
delattr(ptensor, "_origin_length")
delattr(ptensor, "_current_length")
_hijack_back_detach_and_clone(ptensor)
return ptensor
def init_as_padded_tensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
if is_padded_tensor(tensor):
return tensor
tensor._padding_dim = padding_dim
tensor._origin_length = origin_length
tensor._current_length = current_length
_hijack_detach_and_clone(tensor)
return tensor
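A short usage sketch for the padded-tensor helpers defined above, run on a plain CPU tensor with no distributed setup; the shapes are arbitrary examples:

import torch

weight = torch.randn(10, 4)
padded = to_padded_tensor(weight, current_length=16, padding_dim=0)  # zero-pads dim 0 from 10 to 16, in place
assert padded.shape == (16, 4) and is_padded_tensor(padded)

# clone()/detach() are hijacked so the padding metadata travels with the new tensor
assert is_padded_tensor(padded.clone()) and is_padded_tensor(padded.detach())

restored = to_unpadded_tensor(padded)  # slices back to the original length and drops the metadata
assert restored.shape == (10, 4) and not is_padded_tensor(restored)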
@@ -23,7 +23,7 @@ def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1
         rtol=rtol,
         atol=atol,
         msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \
-                   dtype: {a.dtype} vs {b.dtype}",
+        dtype: {a.dtype} vs {b.dtype}",
     )
...
@@ -27,6 +27,12 @@ from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
 )
+from colossalai.tensor.padded_tensor import (
+    init_as_padded_tensor,
+    is_padded_tensor,
+    to_padded_tensor,
+    to_unpadded_tensor,
+)
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
@@ -460,6 +466,11 @@ class GeminiDDP(ModelWrapper):
                         record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn
                     )
                 record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()
+                if is_padded_tensor(tensor):
+                    record_tensor = init_as_padded_tensor(
+                        record_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
+                    )
+                    record_tensor = to_unpadded_tensor(record_tensor)

                 assert tensor not in chunk_to_save_data
                 chunk_to_save_data[tensor] = record_tensor
@@ -520,6 +531,8 @@ class GeminiDDP(ModelWrapper):
                     # deal with ddp ignored parameters
                     destination[prefix + name] = param if keep_vars else param.detach()
                 else:
+                    if is_padded_tensor(p_mapping[param]):
+                        p_mapping[param] = to_unpadded_tensor(p_mapping[param])
                     destination[prefix + name] = p_mapping[param]
         del p_mapping
         del param_to_save_data
@@ -627,6 +640,7 @@ class GeminiDDP(ModelWrapper):
             list, and will be reported together in
             :meth:`~torch.nn.Module.load_state_dict`
         """
+
         for hook in self._load_state_dict_pre_hooks.values():
             hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@@ -647,6 +661,14 @@ class GeminiDDP(ModelWrapper):
             if state_key in state_dict:
                 input_param = state_dict[state_key]

+                global_shape = dest_tensor.shape
+                if source_device_mesh is not None and source_sharding_spec is not None:
+                    global_shape = get_global_shape(dest_tensor)
+
+                if is_padded_tensor(dest_tensor):
+                    padding_dim = dest_tensor._padding_dim
+                    input_param = to_padded_tensor(input_param, global_shape[padding_dim], padding_dim)
+
                 if source_device_mesh is not None and source_sharding_spec is not None:
                     input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
                 elif shard_fn is not None and gather_fn is not None:
...
@@ -21,12 +21,19 @@ from colossalai.tensor.d_tensor import (
     distribute_tensor,
     distribute_tensor_with_customization,
     get_device_mesh,
+    get_global_shape,
     get_sharding_spec,
     init_as_dtensor,
     init_tensor_as_customization_distributed,
     is_customized_distributed_tensor,
     is_distributed_tensor,
 )
+from colossalai.tensor.padded_tensor import (
+    init_as_padded_tensor,
+    is_padded_tensor,
+    to_padded_tensor,
+    to_unpadded_tensor,
+)
 from colossalai.utils import disposable, is_ddp_ignored

 from .chunk import Chunk, ChunkManager
@@ -106,7 +113,7 @@ class GeminiOptimizer(OptimizerWrapper):
         max_norm: float = 0.0,
         norm_type: float = 2.0,
         tp_group: ProcessGroup = None,
-        optimizer_params_info=None,
+        params_info=None,
         verbose: bool = False,
         **defaults: Any,
     ):
@@ -124,7 +131,7 @@ class GeminiOptimizer(OptimizerWrapper):
         self.clipping_flag = max_norm > 0.0
         self.max_norm = max_norm
         self.tp_group = tp_group
-        self.optimizer_params_info = optimizer_params_info
+        self.params_info = params_info
         self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
         self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
         self.verbose = verbose
@@ -459,7 +466,7 @@ class GeminiOptimizer(OptimizerWrapper):
         is_customized_distributed = is_customized_distributed_tensor(param)
         shard_spec = get_sharding_spec(param) if is_dtensor else None
         device_mesh = get_device_mesh(param) if is_dtensor else None
-        global_shape = self.optimizer_params_info["id2shape"][param_id]
+        global_shape = self.params_info["id2shape"][param_id]

         # If the chunk is kept gathered,
         # the parameters are treated the same as that of those in strict DDP during training.
@@ -477,6 +484,7 @@ class GeminiOptimizer(OptimizerWrapper):
                 else:
                     state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
                     if is_dtensor:
+                        global_shape = get_global_shape(param)
                         state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
                         state_tensor = init_as_dtensor(
                             state_tensor,
@@ -490,8 +498,13 @@ class GeminiOptimizer(OptimizerWrapper):
                             state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
                         )
                     state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
-                    collected_states[state_name] = state_tensor.reshape(global_shape)
+                    state_tensor = state_tensor.reshape(global_shape)
+                    if is_padded_tensor(param):
+                        state_tensor = init_as_padded_tensor(
+                            state_tensor, param._current_length, param._origin_length, param._padding_dim
+                        )
+                        state_tensor = to_unpadded_tensor(state_tensor)
+                    collected_states[state_name] = state_tensor
             return collected_states

         # Check whether the param with given id is managed by current process.
@@ -535,6 +548,7 @@ class GeminiOptimizer(OptimizerWrapper):
                 if state_tensor.numel() == param.numel():
                     collected_states[state_name] = torch.reshape(state_tensor, param.shape)
                 if is_dtensor:
+                    global_shape = get_global_shape(param)
                     state_tensor = state_tensor.to(param.device)
                     state_tensor = init_as_dtensor(
                         state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape
@@ -545,6 +559,11 @@ class GeminiOptimizer(OptimizerWrapper):
                         state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
                     )
                     state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
+                    if is_padded_tensor(param):
+                        state_tensor = init_as_padded_tensor(
+                            state_tensor, param._current_length, param._origin_length, param._padding_dim
+                        )
+                        state_tensor = to_unpadded_tensor(state_tensor)

         return collected_states
@@ -698,7 +717,7 @@ class GeminiOptimizer(OptimizerWrapper):
         Load saved optimizer states into parameter with given id.
         """

-        def cast(param, state_range, value, key=None):
+        def cast(param, state_range, value, global_shape, origin_shape, key=None):
             """
             Make a copy of the needed segment of value and cast it to device of param.
             """
@@ -714,7 +733,14 @@ class GeminiOptimizer(OptimizerWrapper):
             )

             if is_dtensor:
-                value = torch.reshape(value, global_shape)
+                global_shape = get_global_shape(real_param)
+
+            if is_padded_tensor(real_param):
+                value = torch.reshape(value, origin_shape)
+                padding_dim = real_param._padding_dim
+                value = to_padded_tensor(value, global_shape[padding_dim], padding_dim)
+
+            if is_dtensor:
                 value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)
             elif is_customized_distributed:
                 value = torch.reshape(value, global_shape)
@@ -737,10 +763,11 @@ class GeminiOptimizer(OptimizerWrapper):
         is_customized_distributed = is_customized_distributed_tensor(real_param)
         shard_spec = get_sharding_spec(real_param) if is_dtensor else None
         device_mesh = get_device_mesh(real_param) if is_dtensor else None
-        global_shape = self.optimizer_params_info["id2shape"][param_id]
+        global_shape = self.params_info["id2shape"][param_id]
+        origin_shape = global_shape

         for k, v in saved_states.items():
-            updated_states[k] = cast(fake_param, state_range, v, k)
+            updated_states[k] = cast(fake_param, state_range, v, global_shape, origin_shape, k)
             del v  # clean loaded states
         self.optim.state[fake_param].update(updated_states)
...
@@ -81,8 +81,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
         optimizer.backward(loss)
         optimizer.step()

-        for group in optimizer.param_groups:
-            group["lr"] = 0.1
+        optimizer.zero_grad()

         with shared_tempdir() as tempdir:
             model_ckpt_path = f"{tempdir}/model"
             optimizer_ckpt_path = f"{tempdir}/optimizer"
...
@@ -21,7 +21,7 @@ def check_vocab_embedding_1d(lazy_init: bool):
     dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None)

     assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
-    assert dist_embedding_1d.num_embeddings == 64
+    assert dist_embedding_1d.num_embeddings == 128
     assert dist_embedding_1d.embedding_dim == 32
     assert embedding_copy.weight is dist_embedding_1d.weight
...
@@ -14,12 +14,14 @@ from torch.testing import assert_close

 from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
+from colossalai.checkpoint_io.utils import gather_distributed_param
 from colossalai.lazy import LazyInitContext
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer._utils import getattr_
 from colossalai.shardformer.policies.auto_policy import Policy
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.tensor.padded_tensor.api import is_padded_tensor, to_unpadded_tensor


 def build_model(
@@ -247,11 +249,10 @@ def check_weight(
             continue

         if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
-            sharded_weight_list = [
-                torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group))
-            ]
-            dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
-            sharded_weight = torch.cat(sharded_weight_list, dim=dim)
+            sharded_weight = gather_distributed_param(sharded_weight, keep_vars=False)
+
+        if is_padded_tensor(sharded_weight):
+            sharded_weight = to_unpadded_tensor(sharded_weight)

         if verbose and dist.get_rank() == 0:
             print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
...
@@ -73,7 +73,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

     # check weights
     if test_config["precision"] == "fp32":
-        atol, rtol = 5e-4, 1e-3
+        # TODO he precision in weight checking is too significant.
+        atol, rtol = 1e-3, 1e-3
     else:
         atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
...
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, is_distributed_tensor, to_global
from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_padded_tensor(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
original_tensor = torch.rand(32, 64).to("cuda")
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)
padded_tensor = to_padded_tensor(d_tensor, current_length=64, padding_dim=0)
assert padded_tensor.dist_layout == d_tensor.dist_layout
tensor_copy = padded_tensor.clone()
assert is_padded_tensor(tensor_copy)
assert is_distributed_tensor(tensor_copy)
tensor_detached = padded_tensor.detach()
assert is_padded_tensor(tensor_detached)
assert is_distributed_tensor(tensor_detached)
unpadded_tensor = to_unpadded_tensor(padded_tensor)
assert unpadded_tensor.shape == d_tensor.shape
assert is_distributed_tensor(unpadded_tensor)
global_tensor = to_global(unpadded_tensor)
assert global_tensor.shape == original_tensor.shape
@rerun_if_address_is_in_use()
def test_padded_tensor():
world_size = 4
spawn(check_padded_tensor, world_size)
if __name__ == "__main__":
test_padded_tensor()