Unverified commit 75af66cd, authored by Zhongkai Zhao and committed by GitHub

[Hotfix] Fix model policy matching strategy in ShardFormer (#5064)

* hotfix/Fix get model policy strategy in ShardFormer

* fix bug in auto policy
parent 4ccb9ded
```diff
@@ -32,7 +32,7 @@ Colossal Inference is composed of three main components:
 In this section we discuss how the colossal inference works and integrates with the `Shardformer`. The details can be found in our codes.
 
-![Colossal-Inference](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-arch.png)
+<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/inference/inference-arch.png" alt="Colossal-Inference" style="zoom: 33%;"/>
 
 ## Roadmap of our implementation
```
```diff
 import importlib
 from dataclasses import dataclass
-from typing import Optional
 
 import torch.nn as nn
 
-from ..shard.shard_config import ShardConfig
 from .base_policy import Policy
 
 __all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]
@@ -150,38 +148,11 @@ _POLICY_LIST = {
     ),
 }
 
-_INFER_POLICY_LIST = {
-    # LlaMa
-    "transformers.models.llama.modeling_llama.LlamaModel": PolicyLocation(
-        file_name="llama", class_name="LlamaModelInferPolicy"
-    ),
-    "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
-        file_name="llama", class_name="LlamaModelInferPolicy"
-    ),
-    # Bloom
-    "transformers.models.bloom.modeling_bloom.BloomModel": PolicyLocation(
-        file_name="bloom", class_name="BloomModelInferPolicy"
-    ),
-    "transformers.models.bloom.modeling_bloom.BloomForCausalLM": PolicyLocation(
-        file_name="bloom", class_name="BloomModelInferPolicy"
-    ),
-    # ChatGLM2
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
-        file_name="chatglm2", class_name="ChatGLM2InferPolicy"
-    ),
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
-        file_name="chatglm2", class_name="ChatGLM2ForConditionalGenerationInferPolicy"
-    ),
-}
```
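With this hunk the inference-only registry is deleted and `_POLICY_LIST` becomes the single lookup table. For readers unfamiliar with the pattern, below is a minimal, self-contained sketch of how such a registry works; the `PolicyLocation` fields match the diff above, and the single entry is copied from the deleted table purely for illustration.

```python
# Minimal sketch of the registry pattern: a dataclass records where a policy
# lives, keyed by the model's fully qualified class name. Illustrative only.
from dataclasses import dataclass


@dataclass
class PolicyLocation:
    file_name: str   # module under the policies package, e.g. "llama"
    class_name: str  # policy class defined inside that module


registry = {
    "transformers.models.llama.modeling_llama.LlamaForCausalLM": PolicyLocation(
        file_name="llama", class_name="LlamaModelInferPolicy"
    ),
}

loc = registry["transformers.models.llama.modeling_llama.LlamaForCausalLM"]
print(loc.file_name, loc.class_name)  # llama LlamaModelInferPolicy
```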
```diff
-def import_policy(policy_location: PolicyLocation, inference_only: Optional[bool] = False) -> Policy:
+def import_policy(policy_location: PolicyLocation) -> Policy:
     """
     Dynamically import a Policy class based on the policy location.
     """
-    if inference_only:
-        module_name = f"colossalai.inference.tensor_parallel.policies.{policy_location.file_name}"
-    else:
-        module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
+    module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
     module = importlib.import_module(module_name)
     return getattr(module, policy_location.class_name)
```
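With the `inference_only` branch gone, `import_policy` always resolves modules under `colossalai.shardformer.policies`; the policies under `colossalai.inference.tensor_parallel.policies` are no longer reachable from this helper. The mechanism itself is plain `importlib`; the sketch below reproduces it in runnable form, using the stdlib `json` module and `JSONDecoder` class as stand-ins for a policy module and policy class so it runs without colossalai installed.

```python
# Runnable sketch of the dynamic-import step: build a module path, import it,
# and fetch a class from it by name.
import importlib


def load_class(module_name: str, class_name: str):
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


decoder_cls = load_class("json", "JSONDecoder")  # stand-in for a Policy class
print(decoder_cls)  # <class 'json.decoder.JSONDecoder'>
```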
```diff
@@ -198,7 +169,7 @@ def _fullname(obj):
     return module + "." + klass.__qualname__
 
-def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
+def get_autopolicy(model: nn.Module) -> Policy:
     r"""
     Return the auto policy for the model
@@ -209,16 +180,12 @@ def get_autopolicy(model: nn.Module, shard_config: ShardConfig = None) -> Policy:
     :class:`Policy`: The auto policy for the model
     """
     full_name = _fullname(model)
-    inference_only = shard_config.extra_kwargs.get("inference_only", None)
-    if inference_only:
-        policy_location = _INFER_POLICY_LIST.get(full_name, None)
-    else:
-        policy_location = _POLICY_LIST.get(full_name, None)
+    policy_location = _POLICY_LIST.get(full_name, None)
     if policy_location is None:
         raise NotImplementedError(
-            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())} and {list(_INFER_POLICY_LIST.keys())}"
+            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
        )
     else:
-        policy = import_policy(policy_location, inference_only)
+        policy = import_policy(policy_location)
     return policy()
```
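This hunk is the core of the fix: the old `get_autopolicy` dereferenced `shard_config.extra_kwargs` even though `shard_config` defaults to `None`, so any caller that left it unset hit an `AttributeError` before the lookup ever ran. After the change, matching depends only on the model instance itself. The runnable sketch below shows the key that `_fullname` (unchanged in the hunk above) produces, with a stdlib class standing in for a model.

```python
# Sketch of the matching key: _fullname joins the class's module path and
# qualified name, and the result is looked up in _POLICY_LIST.
# collections.OrderedDict stands in for a model class here.
from collections import OrderedDict


def _fullname(obj) -> str:
    klass = obj.__class__
    return klass.__module__ + "." + klass.__qualname__


print(_fullname(OrderedDict()))  # collections.OrderedDict
```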
```diff
@@ -28,7 +28,7 @@ class ModelSharder(object):
     def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
         self.model = model
         self.shard_config = shard_config
-        self.policy = get_autopolicy(self.model, shard_config) if policy is None else policy
+        self.policy = get_autopolicy(self.model) if policy is None else policy
 
     def shard(self) -> List[Dict[int, Tensor]]:
         r"""
```
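One consequence worth noting: `ModelSharder` itself defaults `shard_config` to `None`, so before this commit the fallback `get_autopolicy(self.model, shard_config)` could crash on that default. The self-contained sketch below shows the fallback pattern the fixed line uses; the class and helper names are illustrative stand-ins, not colossalai's API.

```python
# Self-contained sketch of the "explicit policy wins, else auto policy" fallback.
from typing import Optional


class DummyPolicy:
    pass


def get_autopolicy(model) -> DummyPolicy:
    # Stand-in for the registry lookup shown earlier; needs only the model.
    return DummyPolicy()


class Sharder:
    def __init__(self, model, policy: Optional[DummyPolicy] = None, shard_config=None):
        self.model = model
        self.shard_config = shard_config
        # Same fallback as the fixed line above: no shard_config needed.
        self.policy = get_autopolicy(self.model) if policy is None else policy


s = Sharder(model=object())
print(type(s.policy).__name__)  # DummyPolicy
```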
...@@ -19,7 +19,6 @@ def build_model( ...@@ -19,7 +19,6 @@ def build_model(
enable_tensor_parallelism=enable_tensor_parallelism, enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention, enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused, enable_jit_fused=enable_jit_fused,
extra_kwargs={"inference_only": True},
) )
model_copy = copy.deepcopy(org_model) model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config) shard_former = ShardFormer(shard_config=shard_config)
......
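For comparison, here is a hedged sketch of what the updated test helper configures after the fix, assuming the `colossalai.shardformer` package layout; the `ShardConfig` flags are the ones visible in the hunk above, with arbitrary values, and `extra_kwargs={"inference_only": True}` is no longer passed.

```python
# Hedged sketch, assuming colossalai is installed; import path assumed.
from colossalai.shardformer import ShardConfig, ShardFormer

shard_config = ShardConfig(
    enable_tensor_parallelism=True,
    enable_flash_attention=False,
    enable_jit_fused=False,
    # extra_kwargs={"inference_only": True},  # removed by this commit
)
shard_former = ShardFormer(shard_config=shard_config)
```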