Unverified commit 19e1a5cf authored by Hongxin Liu, committed by GitHub

[shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests
parent 9a3321e9
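In rough terms, the policies below route model attention through a single ColoAttention entry point that prepares mask kwargs once and dispatches to the best available kernel. A minimal sketch of the intended call pattern, inferred from the new test further down in this diff (shapes, dtype, and the padding mask are illustrative assumptions):

import torch
from colossalai.shardformer.layer import ColoAttention

B, N, S, D = 2, 8, 256, 32  # batch, heads, sequence length, head dim (example values from the test)
q = k = v = torch.rand(B, N, S, D, dtype=torch.float16, device="cuda")
padding_mask = torch.ones(B, S, dtype=torch.int, device="cuda")  # assumed: 1 = keep, 0 = pad

# build mask-related kwargs once, then hand them to the dispatched attention kernel
attn_kwargs = ColoAttention.prepare_attn_kwargs(
    (B, 1, S, S), torch.float16, q.device, q_padding_mask=padding_mask, is_causal=True
)
out = ColoAttention.attention(q, k, v, **attn_kwargs)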
@@ -9,7 +9,12 @@ from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col
from .._utils import getattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
- from ..modeling.opt import OPTPipelineForwards, get_jit_fused_opt_decoder_layer_forward, get_opt_flash_attention_forward
+ from ..modeling.opt import (
+     OPTPipelineForwards,
+     get_jit_fused_opt_decoder_layer_forward,
+     get_opt_decoder_forward_for_flash_attention,
+     get_opt_flash_attention_forward,
+ )
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
@@ -27,6 +32,7 @@ class OPTPolicy(Policy):
import transformers
from packaging.version import Version

+ # TODO: remove this version check when transformers>=4.36.0
assert Version(transformers.__version__) <= Version(
    "4.33.0"
), "The OPT model should run on a transformers version not greater than 4.33.0."
@@ -111,7 +117,9 @@ class OPTPolicy(Policy):
# optimization configuration
self.append_or_create_submodule_replacement(
    description=SubModuleReplacementDescription(
        suffix="final_layer_norm",
        target_module=norm_cls,
        ignore_if_not_exist=True,
    ),
    policy=policy,
    target_key=OPTDecoder,
@@ -119,10 +127,14 @@ class OPTPolicy(Policy):
self.append_or_create_submodule_replacement(
    description=[
        SubModuleReplacementDescription(
            suffix="self_attn_layer_norm",
            target_module=norm_cls,
            ignore_if_not_exist=True,
        ),
        SubModuleReplacementDescription(
            suffix="final_layer_norm",
            target_module=norm_cls,
            ignore_if_not_exist=True,
        ),
    ],
    policy=policy,
@@ -133,11 +145,19 @@ class OPTPolicy(Policy):
if self.shard_config.enable_flash_attention:
    self.append_or_create_method_replacement(
        description={
-             "forward": get_opt_flash_attention_forward(),
+             "forward": get_opt_flash_attention_forward(self.shard_config),
        },
        policy=policy,
        target_key=OPTAttention,
    )
+     if not self.shard_config.pipeline_stage_manager:
+         self.append_or_create_method_replacement(
+             description={
+                 "forward": get_opt_decoder_forward_for_flash_attention(self.shard_config),
+             },
+             policy=policy,
+             target_key=OPTDecoder,
+         )

# use jit fused operator
if self.shard_config.enable_jit_fused:
@@ -190,7 +210,14 @@ class OPTPolicy(Policy):
layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
- method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+ method_replacement = {
+     "forward": partial(
+         new_forward,
+         stage_manager=stage_manager,
+         stage_index=stage_index,
+         shard_config=self.shard_config,
+     )
+ }
self.append_or_create_method_replacement(
    description=method_replacement, policy=policy, target_key=model_cls
)
@@ -203,7 +230,9 @@ class OPTModelPolicy(OPTPolicy):
policy = super().module_policy()
if self.pipeline_stage_manager:
    self.set_pipeline_forward(
        model_cls=OPTModel,
        new_forward=OPTPipelineForwards.opt_model_forward,
        policy=policy,
    )
return policy
@@ -223,14 +252,18 @@ class OPTForCausalLMPolicy(OPTPolicy):
if self.shard_config.enable_tensor_parallelism:
    self.append_or_create_submodule_replacement(
        description=SubModuleReplacementDescription(
            suffix="lm_head",
            target_module=Linear1D_Col,
            kwargs=dict(gather_output=True),
        ),
        policy=policy,
        target_key=OPTForCausalLM,
    )
if self.pipeline_stage_manager:
    self.set_pipeline_forward(
        model_cls=OPTForCausalLM,
        new_forward=OPTPipelineForwards.opt_for_causal_lm_forward,
        policy=policy,
    )
return policy
@@ -246,7 +279,12 @@ class OPTForCausalLMPolicy(OPTPolicy):
if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
    num_stages = self.pipeline_stage_manager.num_stages
    if id(opt_model.model.decoder.embed_tokens.weight) == id(opt_model.lm_head.weight):
        return [
            {
                0: opt_model.model.decoder.embed_tokens.weight,
                num_stages - 1: opt_model.lm_head.weight,
            }
        ]
return []

def postprocess(self):
......
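To exercise this OPT policy path end to end, the flash-attention flag is set on the shard config and the model is run through ShardFormer, roughly as the test helpers later in this diff do. A minimal sketch, assuming an initialized distributed context; the OPT checkpoint name is only an example:

from transformers import OPTForCausalLM
from colossalai.shardformer import ShardConfig, ShardFormer

model = OPTForCausalLM.from_pretrained("facebook/opt-125m")
shard_config = ShardConfig(enable_tensor_parallelism=False, enable_flash_attention=True)  # drives the forward replacements above
sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(model)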
@@ -13,6 +13,7 @@ from ..modeling.whisper import (
WhisperPipelineForwards,
get_jit_fused_whisper_decoder_layer_forward,
get_jit_fused_whisper_encoder_layer_forward,
+ get_whisper_decoder_forward_for_flash_attention,
get_whisper_flash_attention_forward,
)
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -31,6 +32,7 @@ class WhisperPolicy(Policy):
import transformers
from packaging.version import Version

+ # TODO: remove this version check when transformers>=4.36.0
assert Version(transformers.__version__) <= Version(
    "4.33.0"
), "The Whisper model should run on a transformers version not greater than 4.33.0."
@@ -240,6 +242,14 @@ class WhisperPolicy(Policy):
    policy=policy,
    target_key=WhisperAttention,
)
+ if not self.shard_config.pipeline_stage_manager:
+     self.append_or_create_method_replacement(
+         description={
+             "forward": get_whisper_decoder_forward_for_flash_attention(self.shard_config),
+         },
+         policy=policy,
+         target_key=WhisperDecoder,
+     )

# use jit fused operator
if self.shard_config.enable_jit_fused:
@@ -269,7 +279,9 @@ class WhisperPolicy(Policy):
if self.shard_config.enable_tensor_parallelism:
    self.append_or_create_submodule_replacement(
        description=SubModuleReplacementDescription(
            suffix="proj_out",
            target_module=col_nn.Linear1D_Col,
            kwargs={"gather_output": True},
        ),
        policy=base_policy,
        target_key=WhisperForConditionalGeneration,
@@ -326,7 +338,10 @@ class WhisperPolicy(Policy):
if stage < decoder_starting_stage:
    return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
else:
    return Policy.get_stage_index(
        layers_per_stage[decoder_starting_stage:],
        stage - decoder_starting_stage,
    )

def get_held_layers(self) -> List[nn.Module]:
    assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
@@ -422,6 +437,7 @@ class WhisperPolicy(Policy):
        stage_manager=stage_manager,
        stage_index=stage_index,
        decoder_starting_stage=decoder_starting_stage,
+         shard_config=self.shard_config,
    )
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -436,7 +452,9 @@ class WhisperModelPolicy(WhisperPolicy):
if self.pipeline_stage_manager is not None:
    self.set_pipeline_forward(
        model_cls=WhisperModel,
        new_forward=WhisperPipelineForwards.whisper_model_forward,
        policy=policy,
    )
return policy
......
@@ -40,7 +40,12 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
assert torch.all(a == b), f"expected tensors on rank {i} and {i + 1} to be equal but they are not, {a} vs {b}"

def check_state_dict_equal(
    d1: OrderedDict,
    d2: OrderedDict,
    ignore_device: bool = True,
    ignore_dtype: bool = False,
):
    assert len(list(d1.keys())) == len(
        list(d2.keys())
    ), f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
@@ -94,7 +99,12 @@ def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_devic
def assert_hf_output_close(
    out1: Any,
    out2: Any,
    ignore_keys: List[str] = None,
    track_name: str = "",
    atol=1e-5,
    rtol=1e-5,
):
    """
    Check if two outputs from huggingface are equal.
@@ -113,7 +123,12 @@ def assert_hf_output_close(
if ignore_keys is not None and k in ignore_keys:
    continue
assert_hf_output_close(
    out1[k],
    out2[k],
    track_name=f"{track_name}.{k}",
    ignore_keys=ignore_keys,
    atol=atol,
    rtol=rtol,
)
elif isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)):
    # if two values are list
@@ -121,12 +136,17 @@ def assert_hf_output_close(
assert len(out1) == len(out2)
for i in range(len(out1)):
    assert_hf_output_close(
        out1[i],
        out2[i],
        track_name=f"{track_name}.{i}",
        ignore_keys=ignore_keys,
        atol=atol,
        rtol=rtol,
    )
elif isinstance(out1, Tensor) and isinstance(out2, Tensor):
    if out1.shape != out2.shape:
        raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}")
-     assert torch.allclose(
+     assert_close(
        out1, out2, atol=atol, rtol=rtol
    ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}"
else:
......
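For reference, a typical call to this helper in the model tests compares the original and sharded outputs while skipping keys that are expected to diverge (the ignored key and tolerances here are only an example):

assert_hf_output_close(org_output, shard_output, ignore_keys=["past_key_values"], atol=1e-5, rtol=1e-5)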
@@ -101,13 +101,13 @@ class MyExtension(_Extension):
        self._support_jit = True
        self.priority = 10

-     def is_hardware_available(self) -> bool:
+     def is_available(self) -> bool:
        """
        Return if the required hardware can be found.
        """
        ...

-     def assert_hardware_compatible(self) -> None:
+     def assert_compatible(self) -> None:
        """
        Check if the hardware required by the kernel is compatible.
        """
......
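Under the renamed API, a minimal custom extension looks roughly like the new SDPA extension added later in this commit. A sketch only; the kernel name and body are placeholders:

from .base_extension import _Extension

class MyKernelExtension(_Extension):
    def __init__(self):
        super().__init__(name="my_kernel", support_aot=False, support_jit=False)

    def is_available(self) -> bool:
        # report whether the required hardware/runtime can be found
        try:
            import torch
            return torch.cuda.is_available()
        except Exception:
            return False

    def assert_compatible(self) -> None:
        # raise if the detected hardware cannot run this kernel
        pass

    def build_aot(self) -> None:
        raise NotImplementedError("my_kernel does not require ahead-of-time compilation.")

    def build_jit(self) -> None:
        raise NotImplementedError("my_kernel does not require just-in-time compilation.")

    def load(self):
        # return the callable kernel
        def kernel(*args, **kwargs):
            raise NotImplementedError
        return kernel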
from .cpu_adam import CpuAdamArmExtension, CpuAdamX86Extension
- from .flash_attention import (
-     FlashAttentionDaoCudaExtension,
-     FlashAttentionNpuExtension,
-     FlashAttentionXformersCudaExtension,
- )
+ from .flash_attention import FlashAttentionDaoCudaExtension, FlashAttentionNpuExtension, FlashAttentionSdpaCudaExtension
from .layernorm import LayerNormCudaExtension
from .moe import MoeCudaExtension
from .optimizer import FusedOptimizerCudaExtension
@@ -18,7 +14,7 @@ ALL_EXTENSIONS = [
ScaledMaskedSoftmaxCudaExtension,
ScaledUpperTriangleMaskedSoftmaxCudaExtension,
FlashAttentionDaoCudaExtension,
- FlashAttentionXformersCudaExtension,
+ FlashAttentionSdpaCudaExtension,
FlashAttentionNpuExtension,
]
@@ -31,6 +27,6 @@ __all__ = [
"ScaledMaskedSoftmaxCudaExtension",
"ScaledUpperTriangleMaskedSoftmaxCudaExtension",
"FlashAttentionDaoCudaExtension",
- "FlashAttentionXformersCudaExtension",
+ "FlashAttentionSdpaCudaExtension",
"FlashAttentionNpuExtension",
]
@@ -58,13 +58,13 @@ class _Extension(ABC):
        return cache_directory

    @abstractmethod
-     def is_hardware_available(self) -> bool:
+     def is_available(self) -> bool:
        """
        Check if the hardware required by the kernel is available.
        """

    @abstractmethod
-     def assert_hardware_compatible(self) -> None:
+     def assert_compatible(self) -> None:
        """
        Check if the hardware required by the kernel is compatible.
        """
......
@@ -7,11 +7,11 @@ class CpuAdamArmExtension(_CppExtension):
    def __init__(self):
        super().__init__(name="cpu_adam_arm")

-     def is_hardware_available(self) -> bool:
+     def is_available(self) -> bool:
        # only arm allowed
        return platform.machine() == "aarch64"

-     def assert_hardware_compatible(self) -> None:
+     def assert_compatible(self) -> None:
        arch = platform.machine()
        assert (
            arch == "aarch64"
......
@@ -8,15 +8,15 @@ class CpuAdamX86Extension(_CudaExtension):
    def __init__(self):
        super().__init__(name="cpu_adam_x86")

-     def is_hardware_available(self) -> bool:
-         return platform.machine() == "x86_64" and super().is_hardware_available()
+     def is_available(self) -> bool:
+         return platform.machine() == "x86_64" and super().is_available()

-     def assert_hardware_compatible(self) -> None:
+     def assert_compatible(self) -> None:
        arch = platform.machine()
        assert (
            arch == "x86_64"
        ), f"[extension] The {self.name} kernel requires the CPU architecture to be x86_64 but got {arch}"
-         super().assert_hardware_compatible()
+         super().assert_compatible()

    # necessary 4 functions
    def sources_files(self):
......
@@ -22,7 +22,7 @@ class _CudaExtension(_CppExtension):
        This function should return a list of nvcc compilation flags for extensions.
        """

-     def is_hardware_available(self) -> bool:
+     def is_available(self) -> bool:
        # cuda extension can only be built if cuda is available
        try:
            import torch
@@ -32,7 +32,7 @@ class _CudaExtension(_CppExtension):
            cuda_available = False
        return cuda_available

-     def assert_hardware_compatible(self) -> None:
+     def assert_compatible(self) -> None:
        from torch.utils.cpp_extension import CUDA_HOME

        if not CUDA_HOME:
......
from .flash_attention_dao_cuda import FlashAttentionDaoCudaExtension
from .flash_attention_npu import FlashAttentionNpuExtension
- from .flash_attention_xformers_cuda import FlashAttentionXformersCudaExtension
+ from .flash_attention_sdpa_cuda import FlashAttentionSdpaCudaExtension

try:
+     # TODO: remove this after updating openmoe example
    import flash_attention  # noqa

    HAS_FLASH_ATTN = True
except:
    HAS_FLASH_ATTN = False

- try:
-     import xformers  # noqa
-
-     HAS_MEM_EFF_ATTN = True
- except:
-     HAS_MEM_EFF_ATTN = False

- __all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionXformersCudaExtension", "FlashAttentionNpuExtension"]
+ __all__ = ["FlashAttentionDaoCudaExtension", "FlashAttentionSdpaCudaExtension", "FlashAttentionNpuExtension"]
@@ -5,17 +5,20 @@ class FlashAttentionDaoCudaExtension(_Extension):
    def __init__(self):
        super().__init__(name="flash_attention_dao_cuda", support_aot=False, support_jit=False, priority=10)

-     def is_hardware_available(self) -> bool:
+     def is_available(self) -> bool:
        # cuda extension can only be built if cuda is available
        try:
            import torch
+             from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func  # noqa
+             from flash_attn.bert_padding import index_first_axis, pad_input  # noqa

            cuda_available = torch.cuda.is_available()
        except:
            cuda_available = False
        return cuda_available

-     def assert_hardware_compatible(self) -> bool:
+     def assert_compatible(self) -> bool:
        pass

    def build_aot(self) -> None:
@@ -29,65 +32,65 @@ class FlashAttentionDaoCudaExtension(_Extension):
        )

    def load(self):
-         try:
-             from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
-         except ImportError:
-             raise ModuleNotFoundError(
-                 (
-                     "We rely on the third-party flash-attn library for flash attention. Please install flash-attn via 'pip install flash-attn --no-build-isolation'"
-                 )
-             )
-         from typing import Optional
-
-         import torch
-
-         def flash_attention(
-             q: torch.Tensor,
-             k: torch.Tensor,
-             v: torch.Tensor,
-             seq_len_info_q: "SeqLenInfo",
-             seq_len_info_kv: "SeqLenInfo",
-             origin_attn_mask: Optional[torch.Tensor] = None,
-             bias: Optional[torch.Tensor] = None,
-             dropout_p: float = 0.0,
-             scale: float = None,
-             causal: bool = False,
-             padded: bool = False,
-         ):
-             """
-             Arguments:
-                 q: (batch, q_seqlen, nheads, headdim)
-                 k: (batch, kv_seqlen, nheads, headdim)
-                 v: (batch, kv_seqlen, nheads, headdim)
-                 batch_size: int.
-                 seq_len: int.
-                 dropout_p: float. Dropout probability.
-                 sm_scale: float. The scaling of QK^T before applying softmax.
-                     Default to 1 / sqrt(headdim).
-                 causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
-             Return:
-                 attn_out: (batch, q_seqlen, nheads, headdim).
-             """
-             # check if the input is in allowed dtypes
-             if padded:
-                 if seq_len_info_kv == None:
-                     seq_len_info_kv = seq_len_info_q
-
-                 attn_out = flash_attn_varlen_func(
-                     q,
-                     k,
-                     v,
-                     seq_len_info_q.cu_seqlens,
-                     seq_len_info_kv.cu_seqlens,
-                     seq_len_info_q.max_seqlen,
-                     seq_len_info_kv.max_seqlen,
-                     dropout_p,
-                     scale,
-                     causal,
-                 )
-             else:
-                 attn_out = flash_attn_func(q, k, v, dropout_p=dropout_p, softmax_scale=scale, causal=causal)
-             return attn_out
+         from typing import Optional
+
+         import torch
+         from einops import rearrange
+         from flash_attn import flash_attn_func, flash_attn_varlen_kvpacked_func
+         from flash_attn.bert_padding import index_first_axis, pad_input
+
+         def _unpad_input(hidden_states: torch.Tensor, indices: torch.Tensor):
+             return index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices)
+
+         def flash_attention(
+             q: torch.Tensor,
+             k: torch.Tensor,
+             v: torch.Tensor,
+             dropout_p: float = 0.0,
+             scale: Optional[float] = None,
+             attention_mask: Optional[torch.Tensor] = None,
+             is_causal: bool = False,
+             cu_seqlens_q: Optional[torch.Tensor] = None,
+             cu_seqlens_kv: Optional[torch.Tensor] = None,
+             max_seqlen_q: Optional[int] = None,
+             max_seqlen_kv: Optional[int] = None,
+             q_indices: Optional[torch.Tensor] = None,
+             kv_indices: Optional[torch.Tensor] = None,
+         ):
+             # [B, N, S, D] -> [B, S, N, D]
+             q = q.transpose(1, 2)
+             k = k.transpose(1, 2)
+             v = v.transpose(1, 2)
+             b, s_q = q.shape[:2]
+             if cu_seqlens_q is not None:
+                 # padded / padded causal
+                 # unpad input: [B, S, N, D] -> [T, N, D]
+                 q = _unpad_input(q, q_indices)
+                 kv = _unpad_input(torch.stack(tensors=(k, v), dim=2), kv_indices)
+                 attn_output = flash_attn_varlen_kvpacked_func(
+                     q,
+                     kv,
+                     cu_seqlens_q,
+                     cu_seqlens_kv,
+                     max_seqlen_q,
+                     max_seqlen_kv,
+                     dropout_p=dropout_p,
+                     softmax_scale=scale,
+                     causal=is_causal,
+                 )
+                 # pad output: [T, N, D] -> [B, S, N, D]
+                 attn_output = pad_input(attn_output, q_indices, b, s_q)
+             else:
+                 # causal / no attn mask
+                 attn_output = flash_attn_func(
+                     q,
+                     k,
+                     v,
+                     dropout_p=dropout_p,
+                     softmax_scale=scale,
+                     causal=is_causal,
+                 )
+             # [B, S, N, D] -> [B, N, S, D]
+             return attn_output.transpose(1, 2)

        return flash_attention
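The varlen path above expects cumulative sequence lengths and flat token indices. A sketch of how they could be derived from a [B, S] padding mask; this is an assumption about what ColoAttention.prepare_attn_kwargs computes internally, and the helper name is hypothetical:

import torch
import torch.nn.functional as F

def build_varlen_kwargs(padding_mask: torch.Tensor) -> dict:
    # padding_mask: [B, S], 1 for valid tokens, 0 for padding
    seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)                          # valid length per sequence
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))    # [0, l0, l0+l1, ...]
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()      # positions of valid tokens
    max_seqlen = int(seqlens.max())
    return dict(
        cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_kv=max_seqlen,
        q_indices=indices, kv_indices=indices,
    )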
@@ -5,15 +5,15 @@ class FlashAttentionNpuExtension(_Extension):
    def __init__(self):
        super().__init__(name="flash_attention_npu", support_aot=False, support_jit=False)

-     def is_hardware_available(self) -> bool:
+     def is_available(self) -> bool:
        try:
-             import torch_npu  # noqa
-
-             return True
+             import torch_npu
+
+             return hasattr(torch_npu, "npu_fusion_attention")
        except:
            return False

-     def assert_hardware_compatible(self) -> bool:
+     def assert_compatible(self) -> bool:
        pass

    def build_aot(self) -> None:
@@ -27,47 +27,36 @@ class FlashAttentionNpuExtension(_Extension):
        )

    def load(self):
-         import torch
-         from einops import rearrange
-
-         def npu_sdpa_attention(
-             q: torch.Tensor,
-             k: torch.Tensor,
-             v: torch.Tensor,
-             seq_len_info_q=None,
-             seq_len_info_kv=None,
-             origin_attn_mask: torch.Tensor = None,
-             dropout_p: float = 0.0,
-             scale: float = 1.0,
-             causal=None,
-             padded=None,
-         ):
-             """
-             The scaled dot product attention.
-             Arguments:
-                 q: (batch, q_seqlen, nheads, headdim)
-                 k: (batch, kv_seqlen, nheads, headdim)
-                 v: (batch, kv_seqlen, nheads, headdim)
-                 batch_size: int.
-                 seq_len: int.
-                 dropout_p: float. Dropout probability.
-                 scale: float. The scaling of QK^T before applying softmax.
-                     Default to 1.
-             Return:
-                 attn_out: (batch, q_seqlen, nheads, headdim).
-             """
-             q, k, v = [rearrange(x, "b s h d -> b h s d").contiguous() for x in (q, k, v)]
-             output = torch.nn.functional.scaled_dot_product_attention(
-                 q,
-                 k,
-                 v,
-                 attn_mask=origin_attn_mask,
-                 dropout_p=dropout_p,
-                 is_causal=origin_attn_mask is None,
-                 scale=scale,
-             )
-             output = rearrange(output, "b h s d -> b s (h d)")
-             return output
-
-         return npu_sdpa_attention
+         from typing import Optional
+
+         import torch
+         import torch_npu
+
+         def flash_attention(
+             q: torch.Tensor,
+             k: torch.Tensor,
+             v: torch.Tensor,
+             dropout_p: float = 0.0,
+             scale: Optional[float] = None,
+             attention_mask: Optional[torch.Tensor] = None,
+             is_causal: bool = False,
+             cu_seqlens_q: Optional[torch.Tensor] = None,
+             cu_seqlens_kv: Optional[torch.Tensor] = None,
+             max_seqlen_q: Optional[int] = None,
+             max_seqlen_kv: Optional[int] = None,
+             q_indices: Optional[torch.Tensor] = None,
+             kv_indices: Optional[torch.Tensor] = None,
+         ):
+             num_heads = q.size(1)
+             return torch_npu.npu_fusion_attention(
+                 q,
+                 k,
+                 v,
+                 num_heads,
+                 "BNSD",
+                 atten_mask=attention_mask.bool(),
+                 scale=scale,
+                 keep_prob=1 - dropout_p,
+             )[0]
+
+         return flash_attention
from ..base_extension import _Extension
class FlashAttentionSdpaCudaExtension(_Extension):
def __init__(self):
super().__init__(name="flash_attention_sdpa_cuda", support_aot=False, support_jit=False)
def is_available(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
cuda_available = torch.cuda.is_available()
except:
cuda_available = False
return cuda_available
def assert_compatible(self) -> bool:
pass
def build_aot(self) -> None:
raise NotImplementedError("Flash attention SDPA does not require ahead-of-time compilation.")
def build_jit(self) -> None:
raise NotImplementedError("Flash attention SDPA does not require just-in-time compilation.")
def load(self):
from typing import Optional
import torch
def flash_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
dropout_p: float = 0.0,
scale: Optional[float] = None,
attention_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
q_indices: Optional[torch.Tensor] = None,
kv_indices: Optional[torch.Tensor] = None,
):
return torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attention_mask,
dropout_p=dropout_p,
scale=scale,
)
return flash_attention
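A short usage sketch for this SDPA extension; q, k, v and the mask are hypothetical tensors with the [B, N, S, D] and [B, 1, S, S] shapes used in the new attention test, and the mask is the additive float mask that ColoAttention prepares:

ext = FlashAttentionSdpaCudaExtension()
if ext.is_available():
    ext.assert_compatible()
    attn = ext.load()
    out = attn(q, k, v, attention_mask=mask)  # mask: 0 where attended, large negative where dropped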
from ..base_extension import _Extension
class FlashAttentionXformersCudaExtension(_Extension):
def __init__(self):
super().__init__(name="flash_attention_xformers_cuda", support_aot=False, support_jit=False)
def is_hardware_available(self) -> bool:
# cuda extension can only be built if cuda is available
try:
import torch
cuda_available = torch.cuda.is_available()
except:
cuda_available = False
return cuda_available
def assert_hardware_compatible(self) -> bool:
pass
def build_aot(self) -> None:
raise NotImplementedError(
"We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
)
def build_jit(self) -> None:
raise NotImplementedError(
"We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
)
def load(self):
try:
from xformers.ops.fmha import MemoryEfficientAttentionCutlassOp, memory_efficient_attention
from xformers.ops.fmha.attn_bias import (
BlockDiagonalCausalMask,
BlockDiagonalMask,
LowerTriangularMask,
LowerTriangularMaskWithTensorBias,
)
except ImportError:
raise ModuleNotFoundError(
(
"We rely on the third-party xformers library for flash attention (https://github.com/facebookresearch/xformers). Please install xformers according to the GitHub Readme."
)
)
from typing import Optional
import torch
allow_alibi = True
for op in MemoryEfficientAttentionCutlassOp:
allow_alibi = allow_alibi & (LowerTriangularMaskWithTensorBias in op.SUPPORTED_ATTN_BIAS_TYPES)
def mem_eff_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_len_info_q: "SeqLenInfo",
seq_len_info_kv: "SeqLenInfo",
origin_attn_mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: float = None,
causal: bool = False,
padded: bool = False,
):
attn_bias = None
if padded: # bert style
if not causal:
attn_bias = BlockDiagonalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
else:
attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_len_info_q.seqlens, seq_len_info_kv.seqlens)
elif causal: # gpt style
attn_bias = LowerTriangularMask()
if bias is not None: # alibi / relative position embedding
assert allow_alibi, "flash attention with bias is not supported in this system."
assert causal, "attention with bias is only supported for causal attention so far."
attn_bias = attn_bias.add_bias(bias)
if padded:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
out = memory_efficient_attention(q, k, v, attn_bias=attn_bias, p=dropout_p, scale=scale)
# shape: (b*s, n, d)
if padded:
out = out.squeeze(0)
return out
return mem_eff_attention
@@ -80,8 +80,8 @@ if BUILD_EXT:
for ext_cls in ALL_EXTENSIONS:
    ext = ext_cls()
-     if ext.support_aot and ext.is_hardware_available():
-         ext.assert_hardware_compatible()
+     if ext.support_aot and ext.is_available():
+         ext.assert_compatible()
        op_names.append(ext.name)
        ext_modules.append(ext.build_aot())
......
import math
from copy import copy
import torch
from torch.testing import assert_close
from colossalai.kernel.kernel_loader import (
FlashAttentionLoader,
FlashAttentionWithCustomMaskLoader,
FlashAttentionWithPaddingMaskLoader,
)
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer.attn import invert_mask
from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.utils import get_current_device, set_seed
DTYPE = [torch.float16, torch.bfloat16]
B, N, S, D = 2, 8, 256, 32
TOL_MAP = {
torch.float16: {"atol": 5e-4, "rtol": 2e-3},
torch.bfloat16: {},
}
def attention_ref(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask=None, dropout_p=0.0):
head_dim = q.size(-1)
attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
if attn_mask is not None:
attn_weights = attn_weights + attn_mask
attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float).to(q.dtype)
attn_weights = torch.dropout(attn_weights, p=dropout_p, train=True)
attn_output = torch.matmul(attn_weights, v)
return attn_output
def gen_padded_kwargs(dtype: torch.dtype):
padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device())
padding_mask[0, : S // 4] = 0
return (
ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask),
padding_mask,
)
def gen_padded_causal_kwargs(dtype: torch.dtype):
padding_mask = torch.ones((B, S), dtype=torch.int, device=get_current_device())
padding_mask[0, S // 2 :] = 0
return (
ColoAttention.prepare_attn_kwargs(
(B, 1, S, S), dtype, padding_mask.device, q_padding_mask=padding_mask, is_causal=True
),
padding_mask,
)
def gen_causal_kwargs(dtype: torch.dtype):
return ColoAttention.prepare_attn_kwargs((B, 1, S, S), dtype, get_current_device(), is_causal=True), None
def gen_custom_kwargs(dtype: torch.dtype):
attn_mask = torch.ones((B, S, S), dtype=dtype, device=get_current_device())
attn_mask[0, : S // 2, S // 2 :] = 0
attn_mask[0, S // 2 :, : S // 2] = 0
attn_mask[1, :, S // 4 :] = 0
attn_mask = invert_mask(attn_mask).unsqueeze(1)
assert not torch.all(attn_mask != 0, dim=-1).any()
return {"attention_mask": attn_mask}, None
def post_process_kwargs_for_raw_attn(attn_kwargs: dict):
if "attention_mask_type" in attn_kwargs:
attn_kwargs = copy(attn_kwargs)
mask_type = attn_kwargs.pop("attention_mask_type")
attn_kwargs["is_causal"] = mask_type in (AttnMaskType.CAUSAL, AttnMaskType.PADDED_CAUSAL)
return attn_kwargs
def check_attn_func(dtype: torch.dtype, attn_func, attn_kwargs: dict, padding_mask=None):
tols = TOL_MAP[dtype]
q = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True)
k = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True)
v = torch.rand((B, N, S, D), dtype=dtype, device=get_current_device(), requires_grad=True)
q_flash = q.clone().detach().requires_grad_(True)
k_flash = k.clone().detach().requires_grad_(True)
v_flash = v.clone().detach().requires_grad_(True)
attn_mask = attn_kwargs.get("attention_mask", None)
ref_output = attention_ref(q, k, v, attn_mask)
output = attn_func(q_flash, k_flash, v_flash, **attn_kwargs)
if padding_mask is not None:
# [B, Sq] -> [B, 1, Sq, 1]
padding_mask = padding_mask[:, None, :, None].logical_not()
ref_output = ref_output.masked_fill(padding_mask, 0)
output = output.masked_fill(padding_mask, 0)
assert_close(output, ref_output, **tols)
output.mean().backward()
ref_output.mean().backward()
assert_close(q.grad, q_flash.grad, **tols)
assert_close(k.grad, k_flash.grad, **tols)
assert_close(v.grad, v_flash.grad, **tols)
@clear_cache_before_run()
@parameterize("dtype", DTYPE)
def test_flash_attn_func(dtype: torch.dtype):
torch.backends.cudnn.deterministic = True
set_seed(0)
# (func, name, need_postprocess)
avail_attn_funcs = [(ColoAttention.attention, "coloattn", False)]
avail_custom_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)]
avail_padding_mask_attn_funcs = [(ColoAttention.attention, "coloattn", False)]
for ext_cls in FlashAttentionLoader.REGISTRY:
ext = ext_cls()
if ext.is_available():
ext.assert_compatible()
avail_attn_funcs.append((ext.load(), ext.name, True))
for ext_cls in FlashAttentionWithCustomMaskLoader.REGISTRY:
ext = ext_cls()
if ext.is_available():
ext.assert_compatible()
avail_custom_mask_attn_funcs.append((ext.load(), ext.name, True))
for ext_cls in FlashAttentionWithPaddingMaskLoader.REGISTRY:
ext = ext_cls()
if ext.is_available():
ext.assert_compatible()
avail_padding_mask_attn_funcs.append((ext.load(), ext.name, True))
test_sets = {
"none": (lambda dtype: ({}, None), avail_attn_funcs),
"padded": (gen_padded_kwargs, avail_padding_mask_attn_funcs),
"padded_causal": (gen_padded_causal_kwargs, avail_padding_mask_attn_funcs),
"causal": (gen_causal_kwargs, avail_attn_funcs),
"custom": (gen_custom_kwargs, avail_custom_mask_attn_funcs),
}
for mask_type, (gen_kwargs_func, attn_funcs) in test_sets.items():
attn_kwargs, padding_mask = gen_kwargs_func(dtype)
for attn_func, name, need_postprocess in attn_funcs:
print(f"{dtype}, {name}, {mask_type}")
if need_postprocess:
check_attn_func(dtype, attn_func, post_process_kwargs_for_raw_attn(attn_kwargs), padding_mask)
else:
check_attn_func(dtype, attn_func, attn_kwargs, padding_mask)
if __name__ == "__main__":
test_flash_attn_func()
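For completeness, a minimal sketch of picking a kernel directly from one of the loader registries, mirroring the loop in the test above (in practice ColoAttention presumably performs this dispatch internally):

from colossalai.kernel.kernel_loader import FlashAttentionLoader

attn_func = None
for ext_cls in FlashAttentionLoader.REGISTRY:
    ext = ext_cls()
    if ext.is_available():
        ext.assert_compatible()
        attn_func = ext.load()
        break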
@@ -31,6 +31,7 @@ def build_model(
    enable_jit_fused=False,
    enable_sequence_parallelism=False,
    use_lazy_init: bool = False,
+     dtype=torch.float32,
):
    # create new model
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
@@ -51,7 +52,7 @@ def build_model(
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    sharded_model, shared_params = shard_former.optimize(model_copy)
-     return org_model.cuda(), sharded_model.cuda()
+     return org_model.cuda().to(dtype), sharded_model.cuda().to(dtype)

def build_pipeline_model(
@@ -132,7 +133,14 @@ def build_model_from_hybrid_plugin(model_fn: Callable, loss_fn: Callable, test_c
    booster = Booster(plugin=plugin)

    sharded_model, sharded_optimizer, criterion, _, _ = booster.boost(sharded_model, sharded_optimizer, criterion)
    return (
        org_model,
        org_optimizer,
        sharded_model,
        sharded_optimizer,
        criterion,
        booster,
    )

def run_forward_backward_with_hybrid_plugin(
@@ -173,7 +181,12 @@ def run_forward_backward_with_hybrid_plugin(
    data_iter = iter([data])
    sharded_output = booster.execute_pipeline(
        data_iter,
        sharded_model,
        _criterion,
        sharded_optimizer,
        return_loss=True,
        return_outputs=True,
    )
    sharded_loss = sharded_output["loss"]
else:
@@ -313,7 +326,9 @@ def check_grad(

def unwrap_model(
    module: Module,
    base_model_class_name: Optional[str] = None,
    base_model_attribute_name: Optional[str] = None,
):
    if isinstance(module, HybridParallelModule):
        module = module.unwrap()
......
...@@ -45,19 +45,51 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo ...@@ -45,19 +45,51 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
"qformer.encoder.layer[0].attention.output.dense", "qformer.encoder.layer[0].attention.output.dense",
"language_model.model.decoder.layers[0].self_attn.out_proj", "language_model.model.decoder.layers[0].self_attn.out_proj",
] ]
check_grad(blip2, sharded_blip2, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False) check_grad(
check_grad(blip2, sharded_blip2, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False) blip2,
sharded_blip2,
col_layer_for_check,
atol=1e-6,
rtol=1e-5,
dim=0,
verbose=False,
)
check_grad(
blip2,
sharded_blip2,
row_layer_for_check,
atol=1e-6,
rtol=1e-5,
dim=1,
verbose=False,
)
@parameterize("enable_fused_normalization", [True, False]) @parameterize("enable_fused_normalization", [True, False])
@parameterize("enable_tensor_parallelism", [True, False]) @parameterize("enable_tensor_parallelism", [True, False])
@parameterize("enable_flash_attention", [True, False]) @parameterize("enable_flash_attention", [True, False])
@parameterize("enable_jit_fused", [True, False]) @parameterize("enable_jit_fused", [True, False])
def run_blip2_test(
    enable_fused_normalization,
    enable_tensor_parallelism,
    enable_flash_attention,
    enable_jit_fused,
):
    sub_model_zoo = model_zoo.get_sub_registry("transformers_blip2")
    for name, (
        model_fn,
        data_gen_fn,
        output_transform_fn,
        loss_fn,
        _,
    ) in sub_model_zoo.items():
        org_model, sharded_model = build_model(
            model_fn,
            enable_fused_normalization,
            enable_tensor_parallelism,
            enable_flash_attention,
            enable_jit_fused,
+             dtype=torch.float,
        )
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
...@@ -66,7 +98,14 @@ def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable ...@@ -66,7 +98,14 @@ def run_blip2_test(enable_fused_normalization, enable_tensor_parallelism, enable
def check_blip2(rank, world_size, port): def check_blip2(rank, world_size, port):
disable_existing_loggers() disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") colossalai.launch(
config={},
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
run_blip2_test() run_blip2_test()
......
@@ -11,7 +11,6 @@ from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    check_all_grad_tensors,
    check_loss,
-     check_output_hidden_state,
    check_weight,
    get_grad_tensors_for_check,
    run_forward_backward_with_hybrid_plugin,
...@@ -25,7 +24,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -25,7 +24,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
) )
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster,
) )
stage_manager = booster.plugin.stage_manager stage_manager = booster.plugin.stage_manager
...@@ -36,7 +41,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -36,7 +41,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
shard_chatglm_model = unwrap_model(sharded_model, "ChatGLMModel", "transformer") shard_chatglm_model = unwrap_model(sharded_model, "ChatGLMModel", "transformer")
norm_layer_for_check = ["encoder.layers[0].input_layernorm"] norm_layer_for_check = ["encoder.layers[0].input_layernorm"]
row_layer_for_check = ["encoder.layers[0].self_attention.query_key_value", "embedding.word_embeddings"] row_layer_for_check = [
"encoder.layers[0].self_attention.query_key_value",
"embedding.word_embeddings",
]
col_layer_for_check = ["encoder.layers[0].self_attention.dense"] col_layer_for_check = ["encoder.layers[0].self_attention.dense"]
# Save gradient tensors for comparison between the original model and the sharded model. # Save gradient tensors for comparison between the original model and the sharded model.
...@@ -94,8 +102,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -94,8 +102,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
else:
    atol, rtol = 5e-3, 5e-3

- if org_model.__class__.__name__ == "ChatGLMModel":
-     check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
+ # TODO: ChatGLMModel output is [S, B, H], merging batch of pipeline is wrong
+ # if org_model.__class__.__name__ == "ChatGLMModel":
+ #     check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)

check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
...@@ -143,8 +152,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -143,8 +152,20 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"use_lazy_init": False, "use_lazy_init": False,
"precision": "fp32", "precision": "fp32",
}, },
{"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, {
{"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"}, "tp_size": 4,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{ {
"tp_size": 2, "tp_size": 2,
"pp_size": 1, "pp_size": 1,
...@@ -159,7 +180,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -159,7 +180,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
def run_chatglm_test(test_config): def run_chatglm_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm") sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): for name, (
model_fn,
data_gen_fn,
output_transform_fn,
loss_fn,
_,
) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter() clear_layout_converter()
...@@ -193,7 +220,13 @@ def run_chatglm_test(test_config): ...@@ -193,7 +220,13 @@ def run_chatglm_test(test_config):
def run_chatglm_3d_test(test_config): def run_chatglm_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm") sub_model_zoo = model_zoo.get_sub_registry("transformers_chatglm")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): for name, (
model_fn,
data_gen_fn,
output_transform_fn,
loss_fn,
_,
) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter() clear_layout_converter()
...@@ -202,13 +235,27 @@ def run_chatglm_3d_test(test_config): ...@@ -202,13 +235,27 @@ def run_chatglm_3d_test(test_config):
def check_chatglm(rank, world_size, port): def check_chatglm(rank, world_size, port):
disable_existing_loggers() disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") colossalai.launch(
config={},
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
run_chatglm_test() run_chatglm_test()
def check_chatglm_3d(rank, world_size, port): def check_chatglm_3d(rank, world_size, port):
disable_existing_loggers() disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") colossalai.launch(
config={},
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
run_chatglm_3d_test() run_chatglm_3d_test()
......
...@@ -25,7 +25,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -25,7 +25,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
) )
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin( org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster,
) )
stage_manager = booster.plugin.stage_manager stage_manager = booster.plugin.stage_manager
...@@ -47,10 +53,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -47,10 +53,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
else: else:
atol, rtol = 5e-3, 5e-3 atol, rtol = 5e-3, 5e-3
col_layer_grads = get_grad_tensors_for_check( col_layer_grads = get_grad_tensors_for_check(
gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False gpt2,
sharded_gpt2,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
) )
row_layer_grads = get_grad_tensors_for_check( row_layer_grads = get_grad_tensors_for_check(
gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False gpt2,
sharded_gpt2,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False,
) )
norm_layer_grads = get_grad_tensors_for_check( norm_layer_grads = get_grad_tensors_for_check(
...@@ -90,7 +110,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -90,7 +110,16 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 1e-3 atol, rtol = 5e-3, 1e-3
else: else:
atol, rtol = 5e-3, 5e-3 atol, rtol = 5e-3, 5e-3
check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False) check_weight(
gpt2,
sharded_gpt2,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
# check grads # check grads
check_all_grad_tensors(grads_to_check) check_all_grad_tensors(grads_to_check)
@@ -123,14 +152,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    {
        "tp_size": 4,
        "pp_size": 1,
-         "enable_all_optimization": True,
+         "enable_all_optimization": False,
        "use_lazy_init": False,
        "precision": "fp32",
    },
    {
        "tp_size": 2,
        "pp_size": 1,
-         "enable_all_optimization": True,
+         "enable_all_optimization": False,
        "use_lazy_init": False,
        "precision": "fp32",
    },
@@ -138,7 +167,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
    {
        "tp_size": 2,
        "pp_size": 2,
        "num_microbatches": 4,
-         "enable_all_optimization": True,
+         "enable_all_optimization": False,
        "use_lazy_init": True,
        "precision": "fp32",
    },
...@@ -167,7 +196,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -167,7 +196,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
def run_gpt2_test(test_config): def run_gpt2_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj") sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): for name, (
model_fn,
data_gen_fn,
output_transform_fn,
loss_fn,
_,
) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter() clear_layout_converter()
...@@ -202,7 +237,13 @@ def run_gpt2_test(test_config): ...@@ -202,7 +237,13 @@ def run_gpt2_test(test_config):
def run_gpt2_3d_test(test_config): def run_gpt2_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj") sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): for name, (
model_fn,
data_gen_fn,
output_transform_fn,
loss_fn,
_,
) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config) check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter() clear_layout_converter()
...@@ -211,13 +252,27 @@ def run_gpt2_3d_test(test_config): ...@@ -211,13 +252,27 @@ def run_gpt2_3d_test(test_config):
def check_gpt2(rank, world_size, port): def check_gpt2(rank, world_size, port):
disable_existing_loggers() disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") colossalai.launch(
config={},
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
run_gpt2_test() run_gpt2_test()
def check_gpt2_3d(rank, world_size, port): def check_gpt2_3d(rank, world_size, port):
disable_existing_loggers() disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") colossalai.launch(
config={},
rank=rank,
world_size=world_size,
host="localhost",
port=port,
backend="nccl",
)
run_gpt2_3d_test() run_gpt2_3d_test()
......