"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "7b9b86441fbffdd07021f234ec88d0dbc470fa5c"
Commit 14846934 authored by ver217
Browse files

Merge branch 'main' into sync/npu

parents 9102d655 5d9a0ae7
from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
from .ophooks import BaseOpHook, register_ophooks_recursively from .ophooks import BaseOpHook, register_ophooks_recursively
from .stateful_tensor import StatefulTensor from .stateful_tensor import StatefulTensor
from .stateful_tensor_mgr import StatefulTensorMgr from .stateful_tensor_mgr import StatefulTensorMgr
...@@ -11,4 +12,6 @@ __all__ = [ ...@@ -11,4 +12,6 @@ __all__ = [
"AutoTensorPlacementPolicy", "AutoTensorPlacementPolicy",
"register_ophooks_recursively", "register_ophooks_recursively",
"BaseOpHook", "BaseOpHook",
"ColoInitContext",
"post_process_colo_init_ctx",
] ]
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -32,7 +32,7 @@ def set_obj_list_element(obj, attr: str, value): ...@@ -32,7 +32,7 @@ def set_obj_list_element(obj, attr: str, value):
r""" r"""
Set the element to value of a list object Set the element to value of a list object
It used like set_obj_list_element(obj, 'lyaers[0]', new_layer), it will set obj.layers[0] to value It used like set_obj_list_element(obj, 'layers[0]', new_layer), it will set obj.layers[0] to value
Args: Args:
obj (object): The object to set obj (object): The object to set
......
This diff is collapsed.
...@@ -408,7 +408,7 @@ class Linear1D_Row(ParallelModule): ...@@ -408,7 +408,7 @@ class Linear1D_Row(ParallelModule):
handle.wait() handle.wait()
output = torch.cat(output_parallel_list, dim=-1) output = torch.cat(output_parallel_list, dim=-1)
else: else:
output_parallel = F.linear(input_, self.weight) output_parallel = linear_with_async_comm(input_, self.weight, None, None, False)
if self.seq_parallel: if self.seq_parallel:
output = linear_reducescatter_forward_gather_backward( output = linear_reducescatter_forward_gather_backward(
output_parallel, self.process_group, self.seq_parallel_dim output_parallel, self.process_group, self.seq_parallel_dim
......
This diff is collapsed.
...@@ -275,8 +275,8 @@ class FusedRMSNorm(BaseLayerNorm): ...@@ -275,8 +275,8 @@ class FusedRMSNorm(BaseLayerNorm):
) )
LazyInitContext.materialize(module) LazyInitContext.materialize(module)
# to check if it is huggingface LlamaRMSNorm # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
if module.__class__.__name__ == "LlamaRMSNorm": if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
normalized_shape = module.weight.shape[0] normalized_shape = module.weight.shape[0]
eps = module.variance_epsilon eps = module.variance_epsilon
elementwise_affine = True elementwise_affine = True
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -21,6 +21,15 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDe ...@@ -21,6 +21,15 @@ from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDe
class BloomPolicy(Policy): class BloomPolicy(Policy):
def __init__(self) -> None:
super().__init__()
import transformers
from packaging.version import Version
assert Version(transformers.__version__) <= Version(
"4.33.0"
), "The Bloom model should run on a transformers version not greater than 4.33.0."
def config_sanity_check(self): def config_sanity_check(self):
pass pass
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment